Unverified Commit beb2a096 authored by Joao Gante, committed by GitHub

DeepSpeed: hardcode `torch.arange` dtype on `float` usage to avoid incorrect initialization (#28760)
parent f7076cd3
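
In a nutshell: under DeepSpeed ZeRO-3 (`zero.Init`), float tensors created during model initialization are materialized directly in the configured half precision, so position indices built with a float `torch.arange` lose integer exactness. The fix applied across the hunks below is to build the range in `torch.int64`, which DeepSpeed leaves untouched, and only then cast to float. A minimal sketch of the two patterns on plain PyTorch (sizes are illustrative):

import torch

# Downcast-prone pattern: under ZeRO-3 init this float arange can be
# materialized in BF16, where integers above 256 are rounded.
bad = torch.arange(10_000, dtype=torch.float)

# Downcast-preventing pattern used throughout this diff: build the range in
# int64 (integer dtypes pass through untouched), then cast to float explicitly.
good = torch.arange(10_000, dtype=torch.int64).float()

assert torch.equal(bad, good)  # identical outside DeepSpeed

# What BF16 materialization does to the float version: large indices collapse.
print(bad.to(torch.bfloat16)[9990:9994])  # tensor([9984., 9984., 9984., 9984.], dtype=torch.bfloat16)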
@@ -313,8 +313,8 @@ class SpeechT5SinusoidalPositionalEmbedding(nn.Module):
         """
         half_dim = embedding_dim // 2
         emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
-        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
+        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
         if embedding_dim % 2 == 1:
             # zero pad
@@ -403,7 +403,7 @@ class SpeechT5ScaledPositionalEncoding(nn.Module):
     def __init__(self, dropout, dim, max_len=5000):
         pe = torch.zeros(max_len, dim)
         position = torch.arange(0, max_len).unsqueeze(1)
-        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim)))
+        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.int64).float() * -(math.log(10000.0) / dim)))
         pe[:, 0::2] = torch.sin(position.float() * div_term)
         pe[:, 1::2] = torch.cos(position.float() * div_term)
         pe = pe.unsqueeze(0)
...
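All of the model-side hunks in this commit follow the same recipe as the two above, and on vanilla PyTorch the rewrite is a functional no-op. A quick check of that claim, using the sinusoidal formula from the SpeechT5 hunk (dimensions are illustrative):

import math
import torch

num_embeddings, embedding_dim = 1024, 768
half_dim = embedding_dim // 2
scale = math.log(10000) / (half_dim - 1)

# old pattern (float arange) vs. new pattern (int64 arange + explicit cast)
old = torch.exp(torch.arange(half_dim, dtype=torch.float) * -scale)
new = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -scale)

# Identical outside DeepSpeed; the difference only appears when ZeRO-3
# materializes the float arange in half precision at init time.
assert torch.equal(old, new)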
@@ -290,8 +290,8 @@ class Swin2SRSelfAttention(nn.Module):
             )

         # get relative_coords_table
-        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
-        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
+        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
         relative_coords_table = (
             torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
             .permute(1, 2, 0)
...
@@ -446,8 +446,8 @@ class Swinv2SelfAttention(nn.Module):
             )

         # get relative_coords_table
-        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
-        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
+        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.int64).float()
+        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.int64).float()
         relative_coords_table = (
             torch.stack(meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
             .permute(1, 2, 0)
...
@@ -371,7 +371,7 @@ class TableTransformerSinePositionEmbedding(nn.Module):
         y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale
         x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale

-        dim_t = torch.arange(self.embedding_dim, dtype=torch.float32, device=pixel_values.device)
+        dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float()
         dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim)

         pos_x = x_embed[:, :, :, None] / dim_t
...
@@ -85,8 +85,8 @@ class TrOCRSinusoidalPositionalEmbedding(nn.Module):
         """
         half_dim = embedding_dim // 2
         emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
-        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
+        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
         if embedding_dim % 2 == 1:
             # zero pad
...
@@ -264,7 +264,7 @@ class Wav2Vec2BertRotaryPositionalEmbedding(nn.Module):
         dim = config.hidden_size // config.num_attention_heads
         base = config.rotary_embedding_base

-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
         # Ignore copy
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.cached_sequence_length = None
@@ -314,9 +314,9 @@ class Wav2Vec2BertRelPositionalEmbedding(nn.Module):
         # are to the left (i>j) and negative relative positions otherwise (i<j).
         pe_positive = torch.zeros(x.size(1), self.d_model)
         pe_negative = torch.zeros(x.size(1), self.d_model)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
         div_term = torch.exp(
-            torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
+            torch.arange(0, self.d_model, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.d_model)
         )
         pe_positive[:, 0::2] = torch.sin(position * div_term)
         pe_positive[:, 1::2] = torch.cos(position * div_term)
...
@@ -396,7 +396,7 @@ class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module):
         dim = config.hidden_size // config.num_attention_heads
         base = config.rotary_embedding_base

-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
         self.cached_sequence_length = None
         self.cached_rotary_positional_embedding = None
@@ -444,9 +444,9 @@ class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module):
         # are to the left (i>j) and negative relative positions otherwise (i<j).
         pe_positive = torch.zeros(x.size(1), self.d_model)
         pe_negative = torch.zeros(x.size(1), self.d_model)
-        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        position = torch.arange(0, x.size(1), dtype=torch.int64).float().unsqueeze(1)
         div_term = torch.exp(
-            torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model)
+            torch.arange(0, self.d_model, 2, dtype=torch.int64).float() * -(math.log(10000.0) / self.d_model)
         )
         pe_positive[:, 0::2] = torch.sin(position * div_term)
         pe_positive[:, 1::2] = torch.cos(position * div_term)
...
@@ -157,8 +157,8 @@ class XGLMSinusoidalPositionalEmbedding(nn.Module):
         """
         half_dim = embedding_dim // 2
         emb = math.log(10000) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
-        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb)
+        emb = torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1) * emb.unsqueeze(0)
         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
         if embedding_dim % 2 == 1:
             # zero pad
...
@@ -1020,7 +1020,7 @@ class XLNetModel(XLNetPreTrainedModel):
     def relative_positional_encoding(self, qlen, klen, bsz=None):
         # create relative positional encoding.
-        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
+        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.int64).float()
         inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

         if self.attn_type == "bi":
@@ -1033,8 +1033,8 @@ class XLNetModel(XLNetPreTrainedModel):
             raise ValueError(f"Unknown `attn_type` {self.attn_type}.")

         if self.bi_data:
-            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
-            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)
+            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()
+            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.int64).float()

             if self.clamp_len > 0:
                 fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
@@ -1049,7 +1049,7 @@ class XLNetModel(XLNetPreTrainedModel):
             pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
         else:
-            fwd_pos_seq = torch.arange(beg, end, -1.0)
+            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.int64).float()

             if self.clamp_len > 0:
                 fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
             pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
...
@@ -25,6 +25,7 @@ import datasets
 from parameterized import parameterized

 import tests.trainer.test_trainer
+import transformers
 from tests.trainer.test_trainer import TrainerIntegrationCommon  # noqa
 from transformers import AutoModel, TrainingArguments, is_torch_available, logging
 from transformers.integrations.deepspeed import (
@@ -53,6 +54,8 @@ from transformers.utils import SAFE_WEIGHTS_NAME, is_torch_bf16_available_on_dev
 if is_torch_available():
+    import torch
+
     from tests.trainer.test_trainer import (  # noqa
         RegressionModelConfig,
         RegressionPreTrainedModel,
@@ -70,6 +73,7 @@ DEFAULT_MASTER_PORT = "10999"
 T5_SMALL = "t5-small"
 T5_TINY = "patrickvonplaten/t5-tiny-random"
 GPT2_TINY = "sshleifer/tiny-gpt2"
+GPTJ_TINY = "hf-internal-testing/tiny-random-gptj"


 def load_json(path):
@@ -297,6 +301,74 @@ class CoreIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
             torch.allclose(model.new_head.bias, torch.tensor(+100.0, device=model.new_head.bias.device)),
         )

+    def test_arange_bf16(self):
+        # Tests that configuring DeepSpeed with 16 bits does not cause float `torch.arange()` tensors to be cast down.
+        # NOTE -- this assumes that the function calls have the following downcast-preventing pattern, i.e.
+        # `torch.arange(...,dtype=torch.int64)` followed by a cast like `.to(torch.float32)`. 🚨 If this pattern is
+        # NOT applied (e.g. `torch.arange(...,dtype=torch.float32)` is used), DeepSpeed can automatically cast it down
+        # at init time. See https://github.com/huggingface/transformers/issues/28685 for more info.
+        ds_config = {
+            "train_batch_size": 1,
+            "zero_optimization": {
+                "stage": 3,
+            },
+            "bf16": {"enabled": True},
+        }
+
+        dschf = HfDeepSpeedConfig(ds_config)
+
+        self.assertTrue(dschf.is_zero3())
+        self.assertTrue(is_deepspeed_zero3_enabled())
+
+        with LoggingLevel(logging.INFO):
+            with mockenv_context(**self.dist_env_1_gpu):
+                logger = logging.get_logger("transformers.modeling_utils")
+                with CaptureLogger(logger) as cl:
+                    model = AutoModel.from_pretrained(GPTJ_TINY)
+                self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
+
+                # The model weights are in BF16, as per the DeepSpeed config
+                self.assertTrue(str(model.h[0].attn.q_proj.weight.dtype) == "torch.bfloat16")
+                good_deepspeed_sin_cos = model.h[0].attn.embed_positions
+
+        # Monkeypatches the function that creates RoPE embeddings using the INCORRECT torch.arange() pattern, and
+        # then recreates the model
+        def bad_deepspeed_create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
+            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
+            # Incorrect pattern here: torch.arange has a float dtype as its argument (inv_freq.dtype == torch.float32),
+            # and it will automatically be converted to BF16 by DeepSpeed
+            sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=inv_freq.dtype), inv_freq)
+            return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)
+
+        good_deepspeed_create_sinusoidal_positions = transformers.models.gptj.modeling_gptj.create_sinusoidal_positions
+        transformers.models.gptj.modeling_gptj.create_sinusoidal_positions = bad_deepspeed_create_sinusoidal_positions
+
+        with LoggingLevel(logging.INFO):
+            with mockenv_context(**self.dist_env_1_gpu):
+                logger = logging.get_logger("transformers.modeling_utils")
+                with CaptureLogger(logger) as cl:
+                    model = AutoModel.from_pretrained(GPTJ_TINY)
+                self.assertIn("Detected DeepSpeed ZeRO-3", cl.out)
+
+                self.assertTrue(str(model.h[0].attn.q_proj.weight.dtype) == "torch.bfloat16")
+                bad_deepspeed_sin_cos = model.h[0].attn.embed_positions
+
+        # Compares the two sets of values: they are different, and the correct one matches the vanilla torch
+        # (i.e. outside DeepSpeed) version.
+        good_torch_sin_cos = good_deepspeed_create_sinusoidal_positions(
+            model.config.max_position_embeddings, model.config.rotary_dim
+        )
+        self.assertFalse(torch.allclose(good_deepspeed_sin_cos, bad_deepspeed_sin_cos))
+        self.assertTrue(torch.allclose(good_torch_sin_cos, good_deepspeed_sin_cos.cpu()))
+
+        # Finally, we can see that the incorrect pattern is okay on vanilla torch, demonstrating that this issue is
+        # exclusive to DeepSpeed
+        bad_torch_sin_cos = bad_deepspeed_create_sinusoidal_positions(
+            model.config.max_position_embeddings, model.config.rotary_dim
+        )
+        self.assertTrue(torch.allclose(bad_torch_sin_cos, good_torch_sin_cos))
+

 class TrainerIntegrationDeepSpeedWithCustomConfig(TestCasePlus):
     def setUp(self):
...
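For intuition on why the test's "bad" pattern only misbehaves under DeepSpeed: ZeRO-3's `zero.Init` intercepts float tensor creation at init time and materializes the result in the configured dtype. The mismatch can be reproduced on vanilla PyTorch with a stand-in for that downcast (`arange_like_zero_init_bf16` below is a hypothetical illustration, not DeepSpeed's actual hook):

import torch

def arange_like_zero_init_bf16(*args, **kwargs):
    # Hypothetical mimic of the ZeRO-3 init-time downcast: float results are
    # materialized in BF16, while integer dtypes pass through untouched.
    out = torch.arange(*args, **kwargs)
    return out.to(torch.bfloat16) if out.is_floating_point() else out

num_pos, dim = 2048, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))

# Bad pattern: the positions are created as floats, so the mimic rounds them.
bad_pos = arange_like_zero_init_bf16(num_pos, dtype=inv_freq.dtype).float()

# Good pattern: int64 escapes the downcast; the cast to float happens after.
good_pos = arange_like_zero_init_bf16(num_pos, dtype=torch.int64).float()

bad_emb = torch.sin(torch.outer(bad_pos, inv_freq))
good_emb = torch.sin(torch.outer(good_pos, inv_freq))
print(torch.allclose(bad_emb, good_emb))  # False -- positions above 256 were rounded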