Unverified Commit c51dc4f9 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[torch] remove deprecated uint8 in favor of bool (#21384)



* uint8 -> bool

* fix copies

* style

* update test modeling commen when checking attention buffers

* style

* use logical not on random mask instead of subtraction with 1

* remove torch uint8

* quality

* remove modified modeling utils

* Update based on review
Co-authored-by: default avatarsgugger <sylvain.gugger@gmail.com>

---------
Co-authored-by: default avatarsgugger <sylvain.gugger@gmail.com>
parent cc44e72d
......@@ -97,7 +97,7 @@ class CodeGenAttention(nn.Module):
max_positions = config.max_position_embeddings
self.register_buffer(
"causal_mask",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
......
......@@ -145,7 +145,7 @@ class XSoftmax(torch.autograd.Function):
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
)
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
class DropoutContext(object):
......
......@@ -136,7 +136,7 @@ class XSoftmax(torch.autograd.Function):
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
)
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
......
......@@ -115,7 +115,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
......@@ -181,7 +181,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
......
......@@ -127,7 +127,7 @@ class GPT2Attention(nn.Module):
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
......@@ -193,7 +193,7 @@ class GPT2Attention(nn.Module):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
......
......@@ -133,7 +133,7 @@ class GPTNeoSelfAttention(nn.Module):
super().__init__()
max_positions = config.max_position_embeddings
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=bool)).view(
1, 1, max_positions, max_positions
)
......@@ -187,7 +187,7 @@ class GPTNeoSelfAttention(nn.Module):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
......
......@@ -86,7 +86,7 @@ class GPTNeoXAttention(nn.Module):
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
......@@ -193,7 +193,7 @@ class GPTNeoXAttention(nn.Module):
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
......
......@@ -180,13 +180,13 @@ class GPTNeoXJapaneseAttention(nn.Module):
# -> [bs, seq_len, hidden_size]
return tensor
def _create_casual_mask(self, key_length, query_length):
casual_mask = torch.tril(
torch.ones((self.max_positions, self.max_positions), dtype=torch.uint8).view(
def _create_causal_mask(self, key_length, query_length):
causal_mask = torch.tril(
torch.ones((self.max_positions, self.max_positions), dtype=torch.bool).view(
1, 1, self.max_positions, self.max_positions
)
)
return casual_mask[:, :, key_length - query_length : key_length, :key_length].bool()
return causal_mask[:, :, key_length - query_length : key_length, :key_length]
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
......@@ -194,7 +194,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
causal_mask = self._create_casual_mask(key_length, query_length)
causal_mask = self._create_causal_mask(key_length, query_length)
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
......
......@@ -78,7 +78,7 @@ def convert_megatron_checkpoint(sd_megatron, config):
pf = "model.language_model.encoder.layers."
for i in range(layers):
causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.uint8))
causal_mask = torch.tril(torch.ones((n_positions, n_positions), dtype=torch.bool))
causal_mask = causal_mask.view(1, 1, n_positions, n_positions)
sd_hf[f"transformer.h.{i}.attn.bias"] = causal_mask
sd_hf[f"transformer.h.{i}.attn.masked_bias"] = torch.tensor(-1e4, dtype=torch.bfloat16)
......
......@@ -90,7 +90,7 @@ class GPTJAttention(nn.Module):
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
......@@ -155,7 +155,7 @@ class GPTJAttention(nn.Module):
):
# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float32)
......
......@@ -180,7 +180,7 @@ class ImageGPTAttention(nn.Module):
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
......@@ -244,7 +244,7 @@ class ImageGPTAttention(nn.Module):
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
mask_value = torch.finfo(attn_weights.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
......
......@@ -71,7 +71,7 @@ def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
sorted_indices_to_remove[..., 0] = 0
# indices_to_remove = sorted_indices[sorted_indices_to_remove]
indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_(
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter_(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
......
......@@ -404,12 +404,12 @@ def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=Tru
# bool attention mask with True in locations of global attention
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
if before_sep_token is True:
attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8)
attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.bool)
else:
# last token is separation token and should not be counted and in the middle are two separation tokens
attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * (
attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.bool) * (
attention_mask.expand_as(input_ids) < input_ids.shape[-1]
).to(torch.uint8)
).to(torch.bool)
return attention_mask
......
......@@ -666,7 +666,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# add an extra bucket for padding tokens only
num_buckets = num_buckets + 1
# assign padding tokens extra bucket
buckets_mask = attention_mask.to(torch.uint8)[:, None, None, :].expand(buckets.shape)
buckets_mask = attention_mask.to(torch.bool)[:, None, None, :].expand(buckets.shape)
buckets = torch.where(
buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device)
)
......@@ -841,7 +841,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
# attention mask for LSH
if attention_mask is not None:
# if chunked attention, the attention mask has to correspond to LSH order
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
attention_mask = attention_mask.to(torch.bool)[:, None, :]
if not do_standard_self_attention:
# expand attn_mask to fit with key_value_bucket_idx shape
attention_mask = attention_mask[:, None, :]
......@@ -1225,7 +1225,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
):
# chunk attention mask and look before and after
if attention_mask is not None:
attention_mask = attention_mask.to(torch.uint8)[:, None, :]
attention_mask = attention_mask.to(torch.bool)[:, None, :]
if not do_standard_self_attention:
attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1)
......@@ -2159,8 +2159,8 @@ class ReformerModel(ReformerPreTrainedModel):
else:
attention_mask = torch.cat(
[
torch.ones(input_shape, device=device, dtype=torch.uint8),
torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.uint8),
torch.ones(input_shape, device=device, dtype=torch.bool),
torch.zeros((input_shape[0], padding_length), device=device, dtype=torch.bool),
],
dim=-1,
)
......
......@@ -566,7 +566,7 @@ class XSoftmax(torch.autograd.Function):
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
)
output = softmax(g, output, dim)
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.bool)))
# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
......
......@@ -927,7 +927,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mlen = mems[0].size(0) if mems is not None else 0
klen = mlen + qlen
if self.same_length:
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
all_ones = word_emb.new_ones((qlen, klen), dtype=torch.bool)
mask_len = klen - self.mem_len
if mask_len > 0:
mask_shift_len = qlen - mask_len
......@@ -935,7 +935,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
mask_shift_len = qlen
dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len))[:, :, None] # -1
else:
dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.uint8), diagonal=1 + mlen)[
dec_attn_mask = torch.triu(word_emb.new_ones((qlen, klen), dtype=torch.bool), diagonal=1 + mlen)[
:, :, None
]
......
......@@ -442,8 +442,11 @@ class ModelTesterMixin:
# Before we test anything
for key in model_fast_init.state_dict().keys():
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
max_diff = (model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]).sum().item()
else:
max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_save_load_fast_init_to_base(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......@@ -490,10 +493,15 @@ class ModelTesterMixin:
model_slow_init = base_class_copy.from_pretrained(tmpdirname, _fast_init=False)
for key in model_fast_init.state_dict().keys():
if isinstance(model_slow_init.state_dict()[key], torch.BoolTensor):
max_diff = torch.max(
model_slow_init.state_dict()[key] ^ model_fast_init.state_dict()[key]
).item()
else:
max_diff = torch.max(
torch.abs(model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key])
).item()
self.assertLessEqual(max_diff, 1e-5, msg=f"{key} not identical")
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment