"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "8337978f754030e142123e7360742661bc52c47c"
Unverified Commit 2bd79e23 authored by Sam Shleifer, committed by GitHub

[BART] FP16 testing fixes (#3266)

parent 8320feec
src/transformers/modeling_bart.py

@@ -82,7 +82,7 @@ LARGE_NEGATIVE = -1e8
 def _prepare_bart_decoder_inputs(
-    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None,
+    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
 ):
     """Prepare masks that ignore padding tokens in the decoder and a causal lm mask for the decoder if
     none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
@@ -101,6 +101,8 @@ def _prepare_bart_decoder_inputs(
     new_shape = (bsz, tgt_len, tgt_len)
     # make it broadcastable so it can just be added to the attention coefficients
     decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
+    if mask_dtype is not None:
+        decoder_attn_mask = decoder_attn_mask.to(mask_dtype)
     assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
     return decoder_input_ids, decoder_attn_mask
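
Why the new mask_dtype hook matters: after model.half(), the attention scores are torch.float16, while the combined mask is built in float32 around LARGE_NEGATIVE = -1e8. Adding the two either promotes the scores back to fp32 or, on older torch versions, raises a dtype-mismatch error. A minimal sketch of the mismatch, not part of the patch, assuming a CUDA device (the tests below skip half precision on CPU):

import torch

scores = torch.zeros(2, 1, 4, 4, dtype=torch.float16, device="cuda")  # fp16 attention scores
mask = torch.full((2, 1, 4, 4), -1e8, device="cuda")                  # fp32 additive mask

# Without the cast, the sum promotes to fp32 (older torch raises instead),
# so attention silently leaves half precision:
print((scores + mask).dtype)  # torch.float32

# With the cast, -1e8 saturates to -inf in fp16; softmax still treats that as
# fully masked, and everything stays in the model's dtype:
print((scores + mask.to(scores.dtype)).dtype)  # torch.float16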
@@ -838,7 +840,11 @@ class BartModel(PretrainedBartModel):
         # make masks if user doesn't supply
         if not self.decoder.generation_mode:
             decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
-                self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
+                self.config,
+                input_ids,
+                decoder_input_ids=decoder_input_ids,
+                decoder_attn_mask=decoder_attention_mask,
+                mask_dtype=self.shared.weight.dtype,
             )
         assert decoder_input_ids is not None
         if encoder_outputs is None:
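
In the BartModel call site above, the target dtype is read off the shared embedding weight rather than hard-coded, so the mask tracks whatever precision the model is currently in. A tiny sketch of why a parameter's dtype is a reliable proxy, with torch.nn.Embedding standing in for BartModel.shared:

import torch

shared = torch.nn.Embedding(50, 16)  # stand-in for BartModel.shared
print(shared.weight.dtype)           # torch.float32 on a freshly built model
print(shared.half().weight.dtype)    # torch.float16 after Module.half()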
tests/test_modeling_bart.py

@@ -314,10 +314,16 @@ class BartHeadTests(unittest.TestCase):
     @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
     def test_generate_fp16(self):
         config, input_ids, batch_size = self._get_config_and_data(output_past=True)
-        input_ids = input_ids
         attention_mask = input_ids.ne(1).to(torch_device)
+        model = BartForConditionalGeneration(config).eval().to(torch_device).half()
+        model.generate(input_ids, attention_mask=attention_mask, do_sample=False, early_stopping=True)
+
+    @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+    def test_base_model_fp16(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=False)
+        attention_mask = input_ids.ne(1).to(torch_device)
         lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
-        lm_model.generate(input_ids, attention_mask=attention_mask)
+        lm_model(input_ids, attention_mask=attention_mask)

     def test_prepare_bart_decoder_inputs(self):
         config, *_ = self._get_config_and_data(output_past=False)
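
Together the two tests pin down that a half-precision BART can both generate and run a plain forward pass. On a GPU machine they can be run with python -m pytest tests/test_modeling_bart.py -k fp16 (path per the repo layout at the time). A hedged end-to-end sketch of the same property with a real checkpoint; the facebook/bart-large name and the callable tokenizer API come from later transformers releases, not from this commit, and a CUDA device is assumed:

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large").eval().half().to("cuda")

batch = tok(["My friends are cool but they eat too many carbs."], return_tensors="pt").to("cuda")

# A plain forward pass (what test_base_model_fp16 checks) ...
outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"])

# ... and generation (what test_generate_fp16 checks), both entirely in fp16:
summary_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"])
print(tok.decode(summary_ids[0], skip_special_tokens=True))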