Unverified Commit ac98a88f authored by Yih-Dar, committed by GitHub

Fix torchscript tests for GPT-NeoX (#18012)



* fix dtype issue in _attn

* fix RotaryEmbedding

* fix RotaryEmbedding 2

* clean up
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 95113d13
@@ -91,7 +91,9 @@ class GPTNeoXAttention(nn.Module):
             ),
         )
         self.register_buffer("masked_bias", torch.tensor(-1e9))
-        self.rotary_emb = RotaryEmbedding(self.rotary_ndims, base=config.rotary_emb_base)
+        self.rotary_emb = RotaryEmbedding(
+            self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
+        )
         self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype())
         self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
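
For context on the `alpha` change in the next hunk: `self.norm_factor` (visible above) is a plain tensor attribute built at construction time. A quick standalone check of what it evaluates to, with the head size assumed for the sketch:

```python
import torch

head_size = 64  # assumed for the sketch; the real module uses hidden_size // num_attention_heads
norm_factor = torch.sqrt(torch.tensor(head_size, dtype=torch.float32)).to(torch.get_default_dtype())
print(norm_factor)  # tensor(8.) in the default dtype (float32 unless changed)
```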
@@ -207,7 +209,7 @@ class GPTNeoXAttention(nn.Module):
             query,
             key.transpose(1, 2),
             beta=1.0,
-            alpha=(1.0 / self.norm_factor),
+            alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
         )
         attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
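
The old `alpha=(1.0 / self.norm_factor)` divides a Python float by a tensor; the new version builds the numerator as a tensor with the same dtype and device as `norm_factor`, which keeps the scaling factor consistent when the module is traced. A minimal, standalone sketch of the new pattern (shapes and names are made up for illustration):

```python
import torch

# Stand-ins for the real attention tensors; shapes are arbitrary for the sketch.
norm_factor = torch.sqrt(torch.tensor(64.0, dtype=torch.float32)).to(torch.get_default_dtype())
query = torch.randn(2, 8, 64)       # [batch * heads, q_len, head_size]
key_t = torch.randn(2, 64, 8)       # [batch * heads, head_size, k_len]
attn_scores = torch.zeros(2, 8, 8)

# Numerator built from a tensor, so the division result matches norm_factor's
# dtype/device instead of relying on Python-scalar promotion during tracing.
alpha = torch.tensor(1.0, dtype=norm_factor.dtype, device=norm_factor.device) / norm_factor

scores = torch.baddbmm(attn_scores, query, key_t, beta=1.0, alpha=alpha)
print(scores.shape)  # torch.Size([2, 8, 8])
```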
@@ -238,17 +240,24 @@ def attention_mask_func(attention_scores, ltor_mask):
 class RotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
         self.register_buffer("inv_freq", inv_freq)
-        self.max_seq_len_cached = None
-        self.cos_cached = None
-        self.sin_cached = None
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.cos_cached = emb.cos()[None, None, :, :]
+        self.sin_cached = emb.sin()[None, None, :, :]

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
             self.max_seq_len_cached = seq_len
             t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
@@ -256,7 +265,7 @@ class RotaryEmbedding(torch.nn.Module):
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
             self.cos_cached = emb.cos()[None, None, :, :]
             self.sin_cached = emb.sin()[None, None, :, :]
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)


 def rotate_half(x):
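
Taken together, the RotaryEmbedding changes move the sin/cos cache construction into `__init__`, so a traced forward pass never branches on `None` caches. A simplified, standalone sketch of the pattern (not the full GPT-NeoX module; the fallback re-build branch is omitted and `seq_len` is derived from the input so tracing only sees tensor arguments):

```python
import torch

class TraceableRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Build the caches eagerly so tracing records plain tensor ops, not `None` checks.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x):
        # x: [bs, num_attention_heads, seq_len, head_size]
        seq_len = x.shape[-2]
        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)

rotary = TraceableRotaryEmbedding(dim=16, max_position_embeddings=128)
x = torch.randn(1, 4, 10, 16)
traced = torch.jit.trace(rotary, (x,))  # succeeds: no data-dependent Python branching is hit
cos, sin = traced(x)
print(cos.shape, sin.shape)
```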