Unverified commit c51dc4f9 authored by Arthur, committed by GitHub

[torch] remove deprecated uint8 in favor of bool (#21384)



* uint8 -> bool

* fix copies

* style

* update test modeling common when checking attention buffers

* style

* use logical not on random mask instead of subtraction with 1

* remove torch uint8

* quality

* remove modified modeling utils

* Update based on review
Co-authored-by: sgugger <sylvain.gugger@gmail.com>

---------
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
parent cc44e72d
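
The diff below rolls the same dtype change out across many models: causal-mask buffers and attention masks are created as `torch.bool` instead of the deprecated `torch.uint8`, and mask inversion uses logical not rather than subtracting from 1. A minimal sketch of that pattern, with toy sizes and illustrative names (not the exact model code):

```python
import torch

max_positions = 8

# Causal-mask buffer created directly in bool (previously uint8)
causal_mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)

attn_weights = torch.randn(1, 1, max_positions, max_positions)
mask_value = torch.full([], torch.finfo(attn_weights.dtype).min, dtype=attn_weights.dtype)

# A bool mask is usable directly as the condition of torch.where, so the
# .to(torch.bool) / .bool() casts at the call sites become unnecessary
attn_weights = torch.where(causal_mask, attn_weights, mask_value)

# Inverting a random bool mask: the `-` operator is not defined for bool tensors,
# so logical not replaces `1 - mask`
random_mask = torch.rand(2, max_positions) > 0.5
inverted = ~random_mask  # same as random_mask.logical_not()
```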
@@ -97,7 +97,7 @@ class CodeGenAttention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "causal_mask",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
         )
......
@@ -145,7 +145,7 @@ class XSoftmax(torch.autograd.Function):
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
         )
         output = softmax(g, output, dim)
-        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
+        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
 
 class DropoutContext(object):
......
@@ -136,7 +136,7 @@ class XSoftmax(torch.autograd.Function):
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
         )
         output = softmax(g, output, dim)
-        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
+        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
 
 # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
......
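Both DeBERTa copies (and the SEW-D copy further down) change only the fill constant passed to the `masked_fill` ONNX symbolic from a uint8 zero to a bool zero. The eager-mode counterpart of the operation already expects a boolean mask; a small illustrative sketch (not the ONNX symbolic itself):

```python
import torch

scores = torch.randn(2, 4)
# True marks positions to mask out; built directly as a bool tensor
r_mask = torch.tensor([[False, False, True, True],
                       [False, True, True, True]])

# masked_fill expects a boolean mask; uint8 masks are deprecated
masked = scores.masked_fill(r_mask, torch.finfo(scores.dtype).min)
probs = torch.softmax(masked, dim=-1)
masked_probs = probs.masked_fill(r_mask, 0)  # the step whose fill constant changes dtype here
```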
@@ -115,7 +115,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
         )
@@ -181,7 +181,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
......
@@ -127,7 +127,7 @@ class GPT2Attention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
         )
@@ -193,7 +193,7 @@ class GPT2Attention(nn.Module):
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
......
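GPT-2 and the models copied from it (Decision Transformer, GPT-Neo, GPT-NeoX, GPT-J, ImageGPT, CodeGen) follow the same pattern: the `bias` buffer is now bool, so the `.to(torch.bool)`/`.bool()` cast on the sliced causal mask can be dropped. A quick sketch with toy sizes showing that the slice keeps the bool dtype:

```python
import torch

max_positions = 8
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
    1, 1, max_positions, max_positions
)

# With cached keys, the query covers only the new tokens while keys span the full prefix
query_length, key_length = 2, 6
causal_mask = bias[:, :, key_length - query_length : key_length, :key_length]

print(causal_mask.shape)  # torch.Size([1, 1, 2, 6])
print(causal_mask.dtype)  # torch.bool -- slicing preserves the dtype, so no cast is needed
```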
@@ -133,7 +133,7 @@ class GPTNeoSelfAttention(nn.Module):
         super().__init__()
 
         max_positions = config.max_position_embeddings
-        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
             1, 1, max_positions, max_positions
         )
@@ -187,7 +187,7 @@ class GPTNeoSelfAttention(nn.Module):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
         mask_value = torch.finfo(attn_weights.dtype).min
         # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
         # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
......
@@ -86,7 +86,7 @@ class GPTNeoXAttention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
         )
@@ -193,7 +193,7 @@ class GPTNeoXAttention(nn.Module):
         batch_size, num_attention_heads, query_length, attn_head_size = query.size()
         key_length = key.size(-2)
 
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
 
         query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
         key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
......
@@ -180,13 +180,13 @@ class GPTNeoXJapaneseAttention(nn.Module):
         # -> [bs, seq_len, hidden_size]
         return tensor
 
-    def _create_casual_mask(self, key_length, query_length):
-        casual_mask = torch.tril(
-            torch.ones((self.max_positions, self.max_positions), dtype=torch.uint8).view(
+    def _create_causal_mask(self, key_length, query_length):
+        causal_mask = torch.tril(
+            torch.ones((self.max_positions, self.max_positions), dtype=torch.bool).view(
                 1, 1, self.max_positions, self.max_positions
             )
         )
-        return casual_mask[:, :, key_length - query_length : key_length, :key_length].bool()
+        return causal_mask[:, :, key_length - query_length : key_length, :key_length]
 
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
         # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
@@ -194,7 +194,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
         batch_size, num_attention_heads, query_length, attn_head_size = query.size()
         key_length = key.size(-2)
 
-        causal_mask = self._create_casual_mask(key_length, query_length)
+        causal_mask = self._create_causal_mask(key_length, query_length)
 
         query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
         key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
......
@@ -78,7 +78,7 @@ def convert_megatron_checkpoint(sd_megatron, config):
     pf = "model.language_model.encoder.layers."
     for i in range(layers):
-        causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.uint8))
+        causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.bool))
         causal_mask = causal_mask.view(1, 1, n_positions, n_positions)
         sd_hf[f"transformer.h.{i}.attn.bias"] = causal_mask
         sd_hf[f"transformer.h.{i}.attn.masked_bias"] = torch.tensor(-1e4, dtype=torch.bfloat16)
......
@@ -90,7 +90,7 @@ class GPTJAttention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
         )
@@ -155,7 +155,7 @@ class GPTJAttention(nn.Module):
     ):
         # compute causal mask from causal mask buffer
         query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
 
         # Keep the attention weights computation in fp32 to avoid overflow issues
         query = query.to(torch.float32)
......
@@ -180,7 +180,7 @@ class ImageGPTAttention(nn.Module):
         max_positions = config.max_position_embeddings
         self.register_buffer(
             "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                 1, 1, max_positions, max_positions
             ),
         )
@@ -244,7 +244,7 @@ class ImageGPTAttention(nn.Module):
         if not self.is_cross_attention:
             # if only "normal" attention layer implements causal mask
             query_length, key_length = query.size(-2), key.size(-2)
-            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
             mask_value = torch.finfo(attn_weights.dtype).min
             # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
             # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
......
@@ -71,7 +71,7 @@ def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
         sorted_indices_to_remove[..., 0] = 0
 
         # indices_to_remove = sorted_indices[sorted_indices_to_remove]
-        indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_(
+        indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value
......
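The Jukebox sampling helper scatters the sorted removal decisions back into vocabulary order; with the target tensor now created as bool, the scattered mask can index the logits directly. A hedged sketch of nucleus (top-p) filtering in this style; `top_p_filter` is a hypothetical helper written for illustration, not the library function:

```python
import torch

def top_p_filter(logits: torch.Tensor, top_p: float = 0.9, filter_value: float = -float("inf")) -> torch.Tensor:
    # Sort tokens by probability and mark everything past the top-p nucleus for removal
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p            # bool
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False                       # always keep the most likely token
    # Scatter the bool decisions back to the original vocabulary order (bool src into bool target)
    indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
        dim=-1, index=sorted_indices, src=sorted_indices_to_remove
    )
    return logits.masked_fill(indices_to_remove, filter_value)

filtered = top_p_filter(torch.randn(2, 10))
```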
@@ -404,12 +404,12 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru
     # bool attention mask with True in locations of global attention
     attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
     if before_sep_token is True:
-        attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8)
+        attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.bool)
     else:
         # last token is separation token and should not be counted and in the middle are two separation tokens
-        attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * (
+        attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.bool) * (
             attention_mask.expand_as(input_ids) < input_ids.shape[-1]
-        ).to(torch.uint8)
+        ).to(torch.bool)
 
     return attention_mask
......
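LED now builds its global attention mask in bool throughout. Tensor comparisons already return `torch.bool`, so the casts are effectively no-ops, and the product of the two masks in the `else` branch plays the role of a logical AND (spelled explicitly with `&` in this hedged sketch with made-up token IDs):

```python
import torch

input_ids = torch.tensor([[101, 7592, 2088, 102, 2023, 2003, 1037, 102]])
question_end_index = torch.tensor([[3]])  # position of the first sep token (illustrative)

positions = torch.arange(input_ids.shape[1]).expand_as(input_ids)

# Comparisons already yield torch.bool tensors
before_mask = positions < question_end_index

# Global attention after the first sep token, excluding the final position;
# & is the explicit logical AND of the two bool conditions
after_mask = (positions > (question_end_index + 1)) & (positions < input_ids.shape[-1])
```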
@@ -666,7 +666,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
             # add an extra bucket for padding tokens only
             num_buckets = num_buckets + 1
             # assign padding tokens extra bucket
-            buckets_mask = attention_mask.to(torch.uint8)[:, None, None, :].expand(buckets.shape)
+            buckets_mask = attention_mask.to(torch.bool)[:, None, None, :].expand(buckets.shape)
             buckets = torch.where(
                 buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device)
             )
@@ -841,7 +841,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         # attention mask for LSH
         if attention_mask is not None:
             # if chunked attention, the attention mask has to correspond to LSH order
-            attention_mask = attention_mask.to(torch.uint8)[:, None, :]
+            attention_mask = attention_mask.to(torch.bool)[:, None, :]
             if not do_standard_self_attention:
                 # expand attn_mask to fit with key_value_bucket_idx shape
                 attention_mask = attention_mask[:, None, :]
@@ -1225,7 +1225,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
     ):
         # chunk attention mask and look before and after
         if attention_mask is not None:
-            attention_mask = attention_mask.to(torch.uint8)[:, None, :]
+            attention_mask = attention_mask.to(torch.bool)[:, None, :]
 
             if not do_standard_self_attention:
                 attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
@@ -2159,8 +2159,8 @@ class ReformerModel(ReformerPreTrainedModel):
             else:
                 attention_mask = torch.cat(
                     [
-                        torch.ones(input_shape, device=device, dtype=torch.uint8),
-                        torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8),
+                        torch.ones(input_shape, device=device, dtype=torch.bool),
+                        torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.bool),
                     ],
                     dim=-1,
                 )
......
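Reformer casts its attention masks to bool before expanding them and uses them as the condition of `torch.where` to push padding tokens into a dedicated extra bucket. A small hedged sketch of that bucketing step with toy shapes (not the full LSH attention):

```python
import torch

num_buckets = 4
batch, num_hashes, seq_len = 2, 1, 6
buckets = torch.randint(0, num_buckets, (batch, num_hashes, 1, seq_len))

attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]])  # 1 = real token, 0 = padding

# add an extra bucket for padding tokens only
num_buckets = num_buckets + 1
# bool mask selects real tokens; padding positions are reassigned to the extra bucket
buckets_mask = attention_mask.to(torch.bool)[:, None, None, :].expand(buckets.shape)
buckets = torch.where(
    buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device)
)
```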
@@ -566,7 +566,7 @@ class XSoftmax(torch.autograd.Function):
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
         )
         output = softmax(g, output, dim)
-        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
+        return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
 
 # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
......
@@ -927,7 +927,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         mlen = mems[0].size(0) if mems is not None else 0
         klen = mlen + qlen
         if self.same_length:
-            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
+            all_ones = word_emb.new_ones((qlen, klen), dtype=torch.bool)
             mask_len = klen - self.mem_len
             if mask_len > 0:
                 mask_shift_len = qlen - mask_len
@@ -935,7 +935,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                 mask_shift_len = qlen
             dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None]  # -1
         else:
-            dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[
+            dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.bool), diagonal=1 + mlen)[
                 :, :, None
             ]
......
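Transformer-XL combines a `triu` and a `tril` of a ones tensor to build its decoder mask; with the tensor now in bool, the combination is a logical OR of the two triangles (the diff keeps the original `+`; this hedged sketch with toy lengths spells the OR explicitly with `|`):

```python
import torch

qlen, mlen, mem_len = 4, 2, 2
klen = mlen + qlen

all_ones = torch.ones((qlen, klen), dtype=torch.bool)
mask_shift_len = qlen - max(klen - mem_len, 0)

# Upper triangle hides future tokens, lower triangle hides positions beyond the memory window
dec_attn_mask = (torch.triu(all_ones, 1 + mlen) | torch.tril(all_ones, -mask_shift_len))[:, :, None]
print(dec_attn_mask.shape)  # torch.Size([4, 6, 1])
```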
@@ -442,8 +442,11 @@ class ModelTesterMixin:
             # Before we test anything
             for key in model_fast_init.state_dict().keys():
-                max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
-                self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
+                if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
+                    max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item()
+                else:
+                    max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
+                self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
     def test_save_load_fast_init_to_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -490,10 +493,15 @@ class ModelTesterMixin:
                 model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
                 for key in model_fast_init.state_dict().keys():
-                    max_diff = torch.max(
-                        torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
-                    ).item()
-                    self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
+                    if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
+                        max_diff = torch.max(
+                            model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]
+                        ).item()
+                    else:
+                        max_diff = torch.max(
+                            torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
+                        ).item()
+                    self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
 
     def test_initialization(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
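The common test now special-cases bool buffers: subtracting bool tensors is not supported in PyTorch (the `-` operator raises and suggests `~`/`logical_not`), so the comparison uses XOR, whose sum counts mismatching elements. A short illustration:

```python
import torch

a = torch.tensor([True, False, True, True])
b = torch.tensor([True, True, True, False])

# a - b would raise: bool subtraction is unsupported; XOR marks positions that differ
mismatches = (a ^ b).sum().item()
print(mismatches)  # 2
```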