Unverified commit 400c5a15, authored by Stas Bekman, committed by GitHub

[megatron gpt checkpoint conversion] causal mask requires pos_embed dimension (#13735)

parent 91df4551
@@ -121,12 +121,11 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     # The position embeddings.
     pos_embeddings = embeddings["position_embeddings"]["weight"]
-    # Read the hidden dimension.
-    n_embed = pos_embeddings.size(1)
-    # DEBUG.
+    # Read the causal mask dimension (seqlen). [max_sequence_length, hidden_size]
+    n_ctx = pos_embeddings.size(0)
     assert (
-        n_embed == heads * hidden_size_per_head
-    ), f"detected mismatch n_embed={n_embed} != heads={heads}*hidden_size_per_head={hidden_size_per_head}"
+        n_ctx == config.n_ctx
+    ), f"pos_embeddings.max_sequence_length={n_ctx} and config.n_ctx={config.n_ctx} don't match"
     # Store the position embeddings.
     output_state_dict["transformer.wpe.weight"] = pos_embeddings
@@ -175,7 +174,7 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
         ) and weight_or_bias == "weight":
             # Insert a tensor of 1x1xDxD bias.
-            causal_mask = torch.tril(torch.ones((n_embed, n_embed), dtype=torch.float16)).view(1, 1, n_embed, n_embed)
+            causal_mask = torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.float16)).view(1, 1, n_ctx, n_ctx)
             output_state_dict[layer_name + ".attn.bias"] = causal_mask
             # Insert a "dummy" tensor for masked_bias.
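
The second hunk is the consequence of the first: the attention bias stored under ".attn.bias" is a lower-triangular [1, 1, seq_len, seq_len] mask, so it must be sized with n_ctx rather than n_embed; the old code only happened to work when the two were equal. A sketch under the same assumed sizes as above:

import torch

# Assumed sizes for illustration only (same as in the previous sketch).
n_ctx, n_embed = 1024, 1600

# Fixed version: the causal mask is square in the sequence dimension.
causal_mask = torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.float16)).view(1, 1, n_ctx, n_ctx)
assert causal_mask.shape == (1, 1, n_ctx, n_ctx)

# The pre-fix version sized the mask with the hidden dimension, which gives the
# wrong shape whenever n_embed != n_ctx.
old_mask = torch.tril(torch.ones((n_embed, n_embed), dtype=torch.float16)).view(1, 1, n_embed, n_embed)
assert old_mask.shape != causal_mask.shape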