Unverified commit 0fe17f37, authored by Michael Benayoun and committed by GitHub

FX tracing improvement (#14321)

* Change the way tracing happens, enabling dynamic axes out of the box

* Update the tests and modeling xlnet

* Add the non-recording of leaf modules, to avoid recording more values for the recorded methods than will be seen at tracing time (which would otherwise desynchronize the recorded values from the values that need to be given to the proxies during tracing, causing errors)

* Add comments and make tracing work for GPT-J and XLNet

* Refactor things related to num_choices (and batch_size, sequence_length)

* Update fx to work on PyTorch 1.10

* Postpone use of the autowrap_function feature

* Add copyrights

* Remove unnecessary file

* Fix issue with add_new_model_like

* Apply suggestions
parent 552f8d30
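The first bullet is the headline change: dynamic axes now work out of the box. As a hedged sketch of what that means in practice (the `transformers.utils.fx.symbolic_trace` entry point and its exact signature are assumed here, not taken from this diff):

```python
import torch
from transformers import BertModel
from transformers.utils.fx import symbolic_trace  # assumed location of the HF tracer entry point

model = BertModel.from_pretrained("bert-base-uncased")
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

# After this change, the traced GraphModule is expected to accept inputs of
# different batch sizes and sequence lengths without retracing.
out_small = traced(input_ids=torch.zeros(1, 16, dtype=torch.long), attention_mask=torch.ones(1, 16))
out_large = traced(input_ids=torch.zeros(4, 64, dtype=torch.long), attention_mask=torch.ones(4, 64))
```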
@@ -1189,6 +1189,16 @@ def create_new_model_like(
         if "tokenization" not in str(f) and "processor" not in str(f) and "feature_extraction" not in str(f)
     ]
 
+    def disable_fx_test(filename: Path) -> bool:
+        with open(filename) as fp:
+            content = fp.read()
+        new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)
+        with open(filename, "w") as fp:
+            fp.write(new_content)
+        return content != new_content
+
+    disabled_fx_test = False
+
     for test_file in files_to_adapt:
         new_test_file_name = test_file.name.replace(
             old_model_patterns.model_lower_cased, new_model_patterns.model_lower_cased
@@ -1201,6 +1211,13 @@ def create_new_model_like(
             dest_file=dest_file,
             add_copied_from=False,
         )
+        disabled_fx_test = disabled_fx_test | disable_fx_test(dest_file)
+
+    if disabled_fx_test:
+        print(
+            "The tests for symbolic tracing with torch.fx were disabled, you can add those once symbolic tracing works "
+            "for your new model."
+        )
 
     # 4. Add model to auto classes
     add_model_to_auto_classes(old_model_patterns, new_model_patterns, model_classes)
...
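As a quick sanity check of the substitution used by the new `disable_fx_test` helper above (standalone snippet, not part of the patch):

```python
import re

content = "class NewModelTest(ModelTesterMixin, unittest.TestCase):\n    fx_compatible = True\n"
new_content = re.sub(r"fx_compatible\s*=\s*True", "fx_compatible = False", content)
print(new_content)
# class NewModelTest(ModelTesterMixin, unittest.TestCase):
#     fx_compatible = False
```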
@@ -322,7 +322,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HUGGINGFACE_CO_RESOLVE_ENDPOIN
 HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
 
 # This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs.
-TORCH_FX_REQUIRED_VERSION = version.parse("1.9")
+TORCH_FX_REQUIRED_VERSION = version.parse("1.10")
 TORCH_ONNX_DICT_INPUTS_MINIMUM_VERSION = version.parse("1.8")
 
 _is_offline_mode = True if os.environ.get("TRANSFORMERS_OFFLINE", "0").upper() in ENV_VARS_TRUE_VALUES else False
...
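A minimal sketch of the kind of check such a version constant enables (the helper name below is hypothetical, not the library's actual function):

```python
from packaging import version
import torch

TORCH_FX_REQUIRED_VERSION = version.parse("1.10")

def torch_fx_supported() -> bool:  # hypothetical helper, for illustration only
    installed = version.parse(version.parse(torch.__version__).base_version)
    return installed >= TORCH_FX_REQUIRED_VERSION
```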
@@ -247,6 +247,27 @@ class ModuleUtilsMixin:
         return encoder_extended_attention_mask
 
+    def create_extended_attention_mask_for_decoder(self, input_shape, attention_mask, device):
+        batch_size, seq_length = input_shape
+        seq_ids = torch.arange(seq_length, device=device)
+        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
+        # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+        # causal and attention masks must have same type with pytorch version < 1.3
+        causal_mask = causal_mask.to(attention_mask.dtype)
+
+        if causal_mask.shape[1] < attention_mask.shape[1]:
+            prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+            causal_mask = torch.cat(
+                [
+                    torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
+                    causal_mask,
+                ],
+                axis=-1,
+            )
+
+        extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+        return extended_attention_mask
+
     def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device) -> Tensor:
         """
         Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
@@ -271,26 +292,9 @@ class ModuleUtilsMixin:
             # - if the model is a decoder, apply a causal mask in addition to the padding mask
             # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
             if self.config.is_decoder:
-                batch_size, seq_length = input_shape
-                seq_ids = torch.arange(seq_length, device=device)
-                causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
-                # in case past_key_values are used we need to add a prefix ones mask to the causal mask
-                # causal and attention masks must have same type with pytorch version < 1.3
-                causal_mask = causal_mask.to(attention_mask.dtype)
-
-                if causal_mask.shape[1] < attention_mask.shape[1]:
-                    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
-                    causal_mask = torch.cat(
-                        [
-                            torch.ones(
-                                (batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype
-                            ),
-                            causal_mask,
-                        ],
-                        axis=-1,
-                    )
-
-                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+                extended_attention_mask = self.create_extended_attention_mask_for_decoder(
+                    input_shape, attention_mask, device
+                )
             else:
                 extended_attention_mask = attention_mask[:, None, None, :]
         else:
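To visualise what the new `create_extended_attention_mask_for_decoder` helper computes, here is a small standalone reproduction of its core (no `past_key_values` prefix, CPU only):

```python
import torch

batch_size, seq_length = 1, 4
attention_mask = torch.ones(batch_size, seq_length)

seq_ids = torch.arange(seq_length)
# Lower-triangular causal mask: position i may attend to positions <= i.
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.to(attention_mask.dtype)

extended = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
print(extended[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```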
@@ -1861,7 +1865,7 @@ class Conv1D(nn.Module):
     def forward(self, x):
         size_out = x.size()[:-1] + (self.nf,)
         x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(*size_out)
+        x = x.view(size_out)
         return x
...
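The `view(*shape)` → `view(shape)` change here, repeated across the attention modules below, is behaviour-preserving: `Tensor.view` accepts a shape sequence directly. Presumably passing the shape unexpanded keeps it as a single traceable object during torch.fx tracing instead of requiring it to be unpacked (my reading of the change, not stated in the patch). A quick equivalence check:

```python
import torch

x = torch.randn(4, 6)
size_out = x.size()[:-1] + (2, 3)  # a torch.Size, i.e. a tuple of ints

# Both calls produce the same result on eager tensors.
assert torch.equal(x.view(*size_out), x.view(size_out))
print(x.view(size_out).shape)  # torch.Size([4, 2, 3])
```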
@@ -293,7 +293,7 @@ class AlbertAttention(nn.Module):
     # Copied from transformers.models.bert.modeling_bert.BertSelfAttention.transpose_for_scores
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def prune_heads(self, heads):
...
@@ -252,7 +252,7 @@ class BertSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -341,7 +341,7 @@ class BertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -245,7 +245,7 @@ class ElectraSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -334,7 +334,7 @@ class ElectraSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -193,7 +193,7 @@ class GPT2Attention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
-            attn_weights = attn_weights / (float(value.size(-1)) ** 0.5)
+            attn_weights = attn_weights / (value.size(-1) ** 0.5)
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
@@ -281,7 +281,7 @@ class GPT2Attention(nn.Module):
         Splits hidden_size dim into attn_head_size and num_heads
         """
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(*new_shape)
+        tensor = tensor.view(new_shape)
         return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
 
     def _merge_heads(self, tensor, num_heads, attn_head_size):
@@ -915,7 +915,7 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = self.ln_f(hidden_states)
 
-        hidden_states = hidden_states.view(*output_shape)
+        hidden_states = hidden_states.view(output_shape)
         # Add last hidden state
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -1410,7 +1410,7 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
                 f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
             )
-        pooled_logits = logits[range(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
 
         loss = None
         if labels is not None:
...
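The sequence-classification change above swaps Python's `range` for `torch.arange` in the pooling index, presumably so the index is a tensor operation the tracer can record. A standalone illustration of the indexing pattern:

```python
import torch

batch_size, seq_len, num_labels = 3, 5, 2
logits = torch.randn(batch_size, seq_len, num_labels)
sequence_lengths = torch.tensor([4, 2, 3])  # index of the last non-padding token per sequence

# Advanced indexing selects one position per batch element: shape (batch_size, num_labels).
pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
print(pooled_logits.shape)  # torch.Size([3, 2])
```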
@@ -173,7 +173,7 @@ class GPTNeoSelfAttention(nn.Module):
         Splits hidden_size dim into attn_head_size and num_heads
         """
         new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
-        tensor = tensor.view(*new_shape)
+        tensor = tensor.view(new_shape)
         return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
 
     def _merge_heads(self, tensor, num_heads, attn_head_size):
@@ -637,7 +637,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         hidden_states = self.ln_f(hidden_states)
 
-        hidden_states = hidden_states.view(*output_shape)
+        hidden_states = hidden_states.view(output_shape)
         # Add last hidden state
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -891,7 +891,7 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
                 f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
             )
-        pooled_logits = logits[torch.arange(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
 
         loss = None
         if labels is not None:
...
@@ -107,7 +107,7 @@ class GPTJAttention(nn.Module):
         Splits hidden dim into attn_head_size and num_attention_heads
         """
         new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
-        tensor = tensor.view(*new_shape)
+        tensor = tensor.view(new_shape)
         if rotary:
             return tensor
         if len(tensor.shape) == 5:
@@ -665,7 +665,7 @@ class GPTJModel(GPTJPreTrainedModel):
         hidden_states = self.ln_f(hidden_states)
 
-        hidden_states = hidden_states.view(*output_shape)
+        hidden_states = hidden_states.view(output_shape)
         # Add last hidden state
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -945,7 +945,7 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
                 f"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
             )
-        pooled_logits = logits[range(batch_size), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=self.device), sequence_lengths]
 
         loss = None
         if labels is not None:
...
@@ -160,7 +160,7 @@ class LayoutLMSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -249,7 +249,7 @@ class LayoutLMSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -223,7 +223,7 @@ class MegatronBertSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -312,7 +312,7 @@ class MegatronBertSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -237,7 +237,7 @@ class MobileBertSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -274,7 +274,7 @@ class MobileBertSelfAttention(nn.Module):
         context_layer = torch.matmul(attention_probs, value_layer)
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
         return outputs
...
@@ -260,7 +260,7 @@ class RealmSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -349,7 +349,7 @@ class RealmSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -187,7 +187,7 @@ class RobertaSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -276,7 +276,7 @@ class RobertaSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -127,7 +127,7 @@ class SplinterSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -216,7 +216,7 @@ class SplinterSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
@@ -181,7 +181,7 @@ class XLMRobertaXLSelfAttention(nn.Module):
     def transpose_for_scores(self, x):
         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
-        x = x.view(*new_x_shape)
+        x = x.view(new_x_shape)
         return x.permute(0, 2, 1, 3)
 
     def forward(
@@ -270,7 +270,7 @@ class XLMRobertaXLSelfAttention(nn.Module):
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        context_layer = context_layer.view(new_context_layer_shape)
 
         outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
...
This diff is collapsed.
import copy
import functools
import operator
from inspect import signature
from typing import Any, Callable, Dict, Optional, Union
import torch
from torch.fx import Graph, GraphModule, Node
# Torch FX transformation convention:
# - transformations that are supposed to act on a copy of the original GraphModule are decorated with @transformation
# - transformations that are inplace have a name ending with "_"
def _cache_attributes(gm: GraphModule) -> Dict[str, Any]:
attributes_to_keep = [
"config",
"num_choices",
"dummy_inputs",
"use_dynamic_batch_size",
"use_dynamic_sequence_length",
"static_batch_size",
"static_sequence_length",
"static2dynamic",
"dynamic2static",
]
attributes = {k: getattr(gm, k, None) for k in attributes_to_keep}
return attributes
def _restore_attributes_(gm: GraphModule, attributes: Dict[str, Any]):
for name, attr in attributes.items():
setattr(gm, name, attr)
def deepcopy_graph(gm: GraphModule) -> GraphModule:
"""
    Performs a deepcopy of the GraphModule while also copying the attributes that record whether the model was traced
    with dynamic axes and, if so, which values were used.
"""
# First, create a copy of the module without the graph.
graph = gm.__dict__.pop("_graph")
fake_mod = torch.nn.Module()
fake_mod.__dict__ = copy.deepcopy(gm.__dict__)
gm.__dict__["_graph"] = graph
# Then, copy the graph.
val_map = {}
graph_clone = Graph()
output_val = graph_clone.graph_copy(graph, val_map=val_map)
graph_clone.output(output_val)
# Finally create a new GraphModule (or a subclass of GraphModule) from the module and the graph copies.
# gm.__class__ is used to take into account that gm can be an instance of a subclass of GraphModule.
clone = gm.__class__(fake_mod, graph_clone)
# Restore the dynamic axes related attributes to the clone.
attributes = _cache_attributes(gm)
attributes["dynamic2static"] = {val_map.get(k, k): v for k, v in attributes["dynamic2static"].items()}
attributes["static2dynamic"] = {v: k for k, v in attributes["dynamic2static"].items()}
_restore_attributes_(clone, attributes)
return clone
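The graph-copying dance in `deepcopy_graph` relies on stock `torch.fx` APIs; the toy example below shows the same `graph_copy` pattern on a trivially small module (illustration only, unrelated to the Transformers tracer):

```python
import torch
from torch.fx import Graph, GraphModule, symbolic_trace

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = symbolic_trace(Tiny())

# Copy the graph node by node, the same way deepcopy_graph does above.
val_map = {}
graph_clone = Graph()
output_val = graph_clone.graph_copy(gm.graph, val_map=val_map)
graph_clone.output(output_val)

clone = GraphModule(gm, graph_clone)
print(clone(torch.ones(2)))  # tensor([2., 2.])
```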
def transformation(func):
"""
Decorator that wraps a torch.fx transformation by feeding it a copy of the GraphModule to transform instead of the
original.
"""
def map_fn(arg):
if isinstance(arg, GraphModule):
return deepcopy_graph(arg)
return arg
@functools.wraps(func)
def wrapper(*args, **kwargs):
new_args = tuple(map_fn(arg) for arg in args)
new_kwargs = {k: map_fn(v) for k, v in kwargs.items()}
return func(*new_args, **new_kwargs)
wrapper._is_transformation = True
return wrapper
def compose_transformations(
*args: Callable[[GraphModule], Optional[GraphModule]], inplace: bool = False
) -> GraphModule:
"""
    Composes transformations together and takes care of:

        1. Performing the transformations on a copy of the GraphModule if inplace is set to False; transformations
           decorated with @transformation (which already act on a copy rather than on the original GraphModule) are
           unwrapped so that the copy is only made once.
        2. Linting and recompiling only at the end of the composition, for performance purposes.
"""
args = list(args)
if not inplace:
args.insert(0, deepcopy_graph)
for i, transformation in enumerate(args[:-1]):
sig = signature(transformation)
# Unwrapping @transformation decorated transformations as performing the transformations inplace or on a copy is
# already handled by this function.
if getattr(transformation, "_is_transformation", False):
transformation = transformation.__wrapped__
# Linting and recompiling only after the last transformation applied to make composition efficient.
if "lint_and_recompile" in sig.parameters:
args[i] = functools.partial(transformation, lint_and_recompile=False)
def reduce_func(f, g):
def compose_f_and_g(gm):
output_g = g(gm)
if output_g is None:
output_g = gm
output_f = f(output_g)
if output_f is None:
output_f = gm
return output_f
return compose_f_and_g
return functools.reduce(reduce_func, reversed(args), lambda x: x)
def remove_unused_nodes_(gm: GraphModule, lint_and_recompile: bool = True):
"""Removes all the unused nodes in a GraphModule."""
graph = gm.graph
for node in graph.nodes:
if not node.users and node.op not in ["placeholder", "output"]:
graph.erase_node(node)
if lint_and_recompile:
graph.lint()
gm.recompile()
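`remove_unused_nodes_` uses the standard `Graph.erase_node` API; a self-contained example of pruning a dead node from a traced module:

```python
import torch
from torch.fx import symbolic_trace

class WithDeadCode(torch.nn.Module):
    def forward(self, x):
        unused = x * 2  # recorded by the tracer but never used by the output
        return x + 1

gm = symbolic_trace(WithDeadCode())
for node in reversed(list(gm.graph.nodes)):
    if not node.users and node.op not in ["placeholder", "output"]:
        gm.graph.erase_node(node)
gm.graph.lint()
gm.recompile()
print(gm.code)  # the `x * 2` node is gone from the generated forward
```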
def _insert_batch_size_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
"""Inserts a node that retrieves the batch size dynamically from the input of the model."""
graph = gm.graph
input_names = set(gm.dummy_inputs.keys())
batch_size_node = None
for node in graph.nodes:
if node.op == "placeholder" and node.name in input_names:
with graph.inserting_after(node):
batch_size_node = graph.call_method("size", args=(node, 0))
if batch_size_node is None:
raise ValueError("Could not insert the node that computes the batch size")
if lint_and_recompile:
graph.lint()
gm.recompile()
# Useful when retracing for quantization.
if hasattr(gm, "_qconfig_map"):
gm._qconfig_map[batch_size_node.name] = None
return batch_size_node
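`_insert_batch_size_node_` boils down to inserting a `size(0)` call right after an input placeholder; the generic `torch.fx` pattern looks like this (toy module, not the Transformers tracer):

```python
import torch
from torch.fx import symbolic_trace

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = symbolic_trace(Tiny())
graph = gm.graph
for node in graph.nodes:
    if node.op == "placeholder":
        with graph.inserting_after(node):
            # Equivalent to `batch_size = x.size(0)` in the generated code.
            batch_size_node = graph.call_method("size", args=(node, 0))
        break
graph.lint()
gm.recompile()
print(gm.code)
```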
def _insert_encoder_sequence_length_node_(gm: GraphModule, lint_and_recompile: bool = True) -> Node:
"""Inserts a node that retrieves the encoder sequence length dynamically from the input of the model."""
graph = gm.graph
input_names = set(gm.dummy_inputs.keys())
encoder_sequence_length_node = None
for node in graph.nodes:
if node.op == "placeholder" and node.name in input_names and "decoder" not in node.name:
with graph.inserting_after(node):
# There are two cases to handle:
# 1. num_choices < 0, meaning that the model is not performing a "multiple choice" task, in this case the
                # input shape is [batch_size, sequence_length] => index 1
# 2. num_choices > 0, meaning the model is performing a "multiple choice" task, in this case the input
# shape is [batch_size, num_choices, sequence_length] => index 2
encoder_sequence_length_node = graph.call_method("size", args=(node, 1 if gm.num_choices < 0 else 2))
if encoder_sequence_length_node is None:
raise ValueError("Could not insert the node that computes the encoder sequence length")
if lint_and_recompile:
graph.lint()
gm.recompile()
# Useful when retracing for quantization.
if hasattr(gm, "_qconfig_map"):
gm._qconfig_map[encoder_sequence_length_node.name] = None
return encoder_sequence_length_node
def _change_view_methods_(
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
"""
Changes arguments of view ops that refer to static batch size / sequence lengths to make them refer to the
batch_size / sequence_length nodes.
"""
graph = gm.graph
for node in graph.nodes:
if node.op == "call_method" and node.target == "view":
if isinstance(node.args[1], tuple):
node.args = (node.args[0], *node.args[1])
node.args = tuple((mapping.get(arg, arg) for arg in node.args))
if lint_and_recompile:
graph.lint()
gm.recompile()
def _patch_getitem_(
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
"""Patches getitem nodes by replacing current arguments to their corresponding values in mapping."""
# TODO: combine this with the patch_argument function which seems to do almost the same thing.
graph = gm.graph
for node in graph.nodes:
if node.op == "call_function" and node.target == operator.getitem:
indices = node.args[1]
if isinstance(indices, tuple):
new_indices = []
for idx in indices:
if isinstance(idx, slice):
new_indices.append(
slice(
mapping.get(idx.start, idx.start),
mapping.get(idx.stop, idx.stop),
mapping.get(idx.step, idx.step),
)
)
elif isinstance(idx, int):
new_indices.append(mapping.get(idx, idx))
else:
new_indices.append(idx)
node.args = (node.args[0], tuple(new_indices))
else:
node.args = (node.args[0], mapping.get(node.args[1], node.args[1]))
if lint_and_recompile:
graph.lint()
gm.recompile()
def _patch_arguments_(
gm: GraphModule, mapping: Union[Dict[Node, int], Dict[int, Node]], lint_and_recompile: bool = True
):
"""
    Patches nodes by replacing their arguments with the corresponding values in mapping (supports regular types,
    tuples and slices).
"""
def _patch_slice(s, mapping):
return slice(mapping.get(s.start, s.start), mapping.get(s.stop, s.stop), mapping.get(s.step, s.step))
graph = gm.graph
supported_types = (Node, str, int, float)
for node in graph.nodes:
new_args = []
for arg in node.args:
if isinstance(arg, tuple):
new_arg = []
for a in arg:
if isinstance(a, slice):
new_arg.append(_patch_slice(a, mapping))
else:
new_arg.append(mapping.get(a, a))
new_args.append(tuple(new_arg))
elif isinstance(arg, slice):
new_args.append(_patch_slice(arg, mapping))
elif isinstance(arg, supported_types):
new_args.append(mapping.get(arg, arg))
else:
new_args.append(arg)
node.args = tuple(new_args)
if lint_and_recompile:
graph.lint()
gm.recompile()
def transform_to_dynamic_input_(gm: GraphModule, is_retracing: bool = False):
"""Transformation that enables traced models to perform inference on dynamic input shapes."""
graph = gm.graph
static2dynamic = {}
# Inserting the nodes that will fetch the batch size and sequence lengths dynamically.
if gm.use_dynamic_batch_size:
batch_size_node = _insert_batch_size_node_(gm, lint_and_recompile=False)
static2dynamic[gm.static_batch_size] = batch_size_node
if gm.num_choices > 0:
with graph.inserting_after(batch_size_node):
static2dynamic[gm.static_batch_size * gm.num_choices] = graph.call_function(
operator.mul, args=(batch_size_node, gm.num_choices)
)
# Useful when retracing for quantization.
if hasattr(gm, "_qconfig_map"):
gm._qconfig_map[static2dynamic[gm.static_batch_size * gm.num_choices]] = None
if gm.use_dynamic_sequence_length:
encoder_sequence_length_node = _insert_encoder_sequence_length_node_(gm, lint_and_recompile=False)
static2dynamic[gm.static_sequence_length[0]] = encoder_sequence_length_node
# TODO: do the same for the decoder.
pass
_change_view_methods_(gm, static2dynamic, lint_and_recompile=False)
_patch_getitem_(gm, static2dynamic, lint_and_recompile=False)
remove_unused_nodes_(gm, lint_and_recompile=False)
graph.lint()
gm.recompile()
gm.static2dynamic = static2dynamic
gm.dynamic2static = {v: k for (k, v) in static2dynamic.items()}
@@ -231,8 +231,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True
 
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
...
@@ -444,8 +444,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
-    fx_ready_model_classes = all_model_classes
-    fx_dynamic_ready_model_classes = all_model_classes
+    fx_compatible = True
 
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
...
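The test changes replace the per-class `fx_ready_model_classes` / `fx_dynamic_ready_model_classes` lists with a single `fx_compatible` flag. A hypothetical sketch of how such a flag can gate a shared test (not the actual `ModelTesterMixin` code):

```python
import unittest

class FxTestMixin:
    fx_compatible = False  # model test classes opt in by setting this to True

    def test_torch_fx_tracing(self):
        if not self.fx_compatible:
            self.skipTest("torch.fx symbolic tracing is not supported for this model")
        # otherwise: run symbolic tracing here and compare traced vs. eager outputs
```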