"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "1000bcaeb7d035d692a534280bccf9b710588a94"
Unverified Commit 026a5d08 authored by Patrick von Platen, committed by GitHub

[T5 fp16] Fix fp16 in T5 (#4436)

* fix fp16 in t5

* make style

* refactor invert_attention_mask fn

* fix typo
parent fa6113f9
@@ -149,8 +149,12 @@ class T5LayerNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, x):
-        variance = x.pow(2).mean(-1, keepdim=True)
+        # layer norm should always be calculated in float32
+        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
         x = x / torch.sqrt(variance + self.variance_epsilon)
+
+        if self.weight.dtype == torch.float16:
+            x = x.to(torch.float16)
         return self.weight * x
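The float32 accumulation added above matters because fp16 has a small dynamic range: its maximum is 65504, so squaring activations with magnitude above roughly 256 already overflows inside the variance computation. A standalone illustration with made-up values (assumes a reasonably recent PyTorch for fp16 ops on CPU):

import torch

print(torch.finfo(torch.float16).max)                           # 65504.0
x = torch.full((2, 8), 300.0, dtype=torch.float16)
print(x.pow(2).mean(-1, keepdim=True))                          # inf: 300**2 does not fit in fp16
print(x.to(torch.float32).pow(2).mean(-1, keepdim=True))        # 90000.0, as expected

Casting back to float16 at the end keeps the output dtype consistent with the fp16 weights, so the rest of the network stays in half precision.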
@@ -691,7 +695,9 @@ class T5Stack(T5PreTrainedModel):
             attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
         if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
             encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(inputs_embeds.device)
+            encoder_attention_mask = torch.ones(
+                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
+            )
 
         # initialize past_key_value_states with `None` if past does not exist
         if past_key_value_states is None:
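The functional change in the hunk above is that the default all-ones cross-attention mask is now created directly on the right device as an integer tensor; turning it into a floating additive bias is left to invert_attention_mask, which is patched below to be dtype-aware. A small sketch with made-up shapes:

import torch

batch_size, encoder_seq_length = 2, 7
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 0/1 integer mask; invert_attention_mask later casts it to the model's
# floating dtype before building the additive attention bias.
encoder_attention_mask = torch.ones(
    batch_size, encoder_seq_length, device=device, dtype=torch.long
)
print(encoder_attention_mask.dtype, encoder_attention_mask.device)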
@@ -733,6 +739,7 @@ class T5Stack(T5PreTrainedModel):
             # layer_outputs is a tuple with:
             # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             hidden_states, present_key_value_state = layer_outputs[:2]
+
             if i == 0:
                 # We share the position biases between the layers - the first layer store them
                 # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
......
@@ -128,7 +128,18 @@ class ModuleUtilsMixin:
         # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
         # encoder_extended_attention_mask.transpose(-1, -2))
         encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-        encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
+
+        if self.dtype == torch.float16:
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
+        elif self.dtype == torch.float32:
+            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
+        else:
+            raise ValueError(
+                "{} not recognized. `dtype` should be set to either `torch.float32` or `torch.float16`".format(
+                    self.dtype
+                )
+            )
         return encoder_extended_attention_mask
 
     def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: tuple, device: device):
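The dtype switch above is needed because the additive mask constant has to stay representable in the model's precision: -1e9 overflows to -inf in fp16, and a score row that ends up all -inf (for example a fully padded sequence) turns into NaN after the softmax. A quick illustration, independent of the model code:

import torch

print(torch.tensor(-1e9, dtype=torch.float16))   # tensor(-inf, dtype=torch.float16)
print(torch.tensor(-1e4, dtype=torch.float16))   # tensor(-10000., dtype=torch.float16)

# An all -inf score row softmaxes to NaN, which then propagates through the network.
scores = torch.full((1, 4), float("-inf"))
print(torch.softmax(scores, dim=-1))             # tensor([[nan, nan, nan, nan]])

-1e4 still fits comfortably in fp16 while remaining large enough to zero out masked positions after the softmax.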
......
@@ -304,6 +304,16 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
             output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
             self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
 
+        def create_and_check_t5_model_fp16_forward(
+            self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
+        ):
+            model = T5Model(config=config)
+            model.to(torch_device)
+            model.half()
+            model.eval()
+            output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0]
+            self.parent.assertFalse(torch.isnan(output).any().item())
+
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (
@@ -355,6 +365,11 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*config_and_inputs)
 
+    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    def test_t5_model_fp16_forward(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_t5_model_fp16_forward(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......