Commit 73df3be7 authored by Tri Dao

Add test for BTLM init

parent 7ffba9a5
@@ -396,7 +396,9 @@ def _init_weights(
     mup_init_scale = math.sqrt(mup_width_scale)
     if isinstance(module, nn.Linear):
         nn.init.normal_(module.weight, std=initializer_range * mup_init_scale)
-        module.weight._optim = {"lr_multiplier": mup_width_scale}
+        optim_cfg = getattr(module.weight, "_optim", {})
+        optim_cfg.update({"lr_multiplier": mup_width_scale})
+        setattr(module.weight, "_optim", optim_cfg)
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
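Note: the hunk above now merges into any existing `_optim` dict instead of overwriting it. The consumer of this tag is not part of this commit; as a rough sketch (assuming a plain PyTorch optimizer setup, not flash-attn's actual training code), the per-parameter `lr_multiplier` could be applied via optimizer param groups:

```python
import torch

# Sketch only: group parameters by the `_optim` tag set in `_init_weights`,
# scaling the base lr by each parameter's `lr_multiplier` (muP-style).
def build_param_groups(model, base_lr):
    groups = []
    for param in model.parameters():
        cfg = getattr(param, "_optim", {})
        groups.append({"params": [param], "lr": base_lr * cfg.get("lr_multiplier", 1.0)})
    return groups

# optimizer = torch.optim.AdamW(build_param_groups(model, base_lr=6e-4))
```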
...
 # Copyright (c) 2023, Tri Dao.
-import os
 import time
-from pathlib import Path
 import torch
 import pytest
-from einops import rearrange
 from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
 from flash_attn.models.gpt import GPTLMHeadModel
@@ -21,9 +17,7 @@ def test_btlm_state_dict(model_name):
     config = btlm_config_to_gpt2_config(
         AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
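Note: `device="meta"` builds the model with storage-free meta tensors, so the multi-billion-parameter random init is skipped entirely; the test only needs parameter names and shapes to compare against the remapped checkpoint. A minimal illustration of the mechanism (standard PyTorch behavior, not code from this commit):

```python
import torch
from torch import nn

# Meta tensors carry shape/dtype metadata but allocate no storage, so
# construction is nearly free; enough for state-dict key/shape checks.
layer = nn.Linear(2560, 2560, device="meta")
print(layer.weight.shape)    # torch.Size([2560, 2560])
print(layer.weight.is_meta)  # True
```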
@@ -47,9 +41,7 @@ def test_btlm_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -152,9 +144,7 @@ def test_btlm_generation(model_name):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
     del model_ref
-    pretrained_state_dict = remap_state_dict_hf_btlm(
-        state_dict_from_pretrained(model_name), config
-    )
+    pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -212,3 +202,44 @@ def test_btlm_generation(model_name):
     assert torch.equal(logits_cg, logits)
 
 
+@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
+def test_btlm_init(model_name):
+    dtype = torch.float32
+    device = "cuda"
+    btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    config = btlm_config_to_gpt2_config(btlm_config)
+    model = GPTLMHeadModel(config, device=device, dtype=dtype)
+    model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)
+    assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4
+    assert (
+        model.transformer.embeddings.word_embeddings.weight.std()
+        - model_ref.transformer.wte.weight.std()
+    ).abs() < 1e-4
+    assert model.lm_head.weight.mean().abs() < 1e-4
+    assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4
+    for l in range(config.n_layer):
+        assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.Wqkv.weight.std()
+            - model_ref.transformer.h[l].attn.c_attn.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mixer.out_proj.weight.std()
+            - model_ref.transformer.h[l].attn.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc1.weight.std()
+            - model_ref.transformer.h[l].mlp.c_fc.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0
+        assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4
+        assert (
+            model.transformer.layers[l].mlp.fc2.weight.std()
+            - model_ref.transformer.h[l].mlp.c_proj.weight.std()
+        ).abs() < 1e-4
+        assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0
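Note: the new test applies the same two checks to every weight matrix (near-zero mean, std within 1e-4 of the HF reference implementation) plus exactly-zero biases. A hypothetical helper that condenses the repetition (not part of the commit):

```python
# Hypothetical refactor of the repeated assertions above:
def assert_matches_ref(w, w_ref, atol=1e-4):
    assert w.mean().abs() < atol                 # init is zero-mean
    assert (w.std() - w_ref.std()).abs() < atol  # std tracks the HF reference

def assert_zero_bias(b):
    assert b.abs().max() == 0.0                  # biases initialized to exactly 0
```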