Commit 798858f9 authored by Tri Dao

Fix test_baichuan

parent 7b33743a
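The test changes below drop the local CHECKPOINT_DIR / combine_state_dicts_tp loading path and pull everything from the Hugging Face Hub instead. A minimal sketch of the loading path the updated tests converge on, with prompt handling and tolerance checks omitted (the helper names come from the hunks below; the exact import locations are assumptions):

import torch
from transformers import AutoConfig

# Helper names appear in the diff below; their import paths here are assumed.
from flash_attn.models.baichuan import (
    baichuan_config_to_gpt2_config,
    remap_state_dict_hf_baichuan,
)
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "baichuan-inc/Baichuan-7B"
dtype, device = torch.float16, "cuda"

# Config comes straight from the Hub (trust_remote_code because Baichuan ships custom modeling code).
config = baichuan_config_to_gpt2_config(
    AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)

# Remap the HF checkpoint to the flash-attn GPT layout and load it.
pretrained_state_dict = remap_state_dict_hf_baichuan(
    state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.load_state_dict(pretrained_state_dict)
model.eval()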
@@ -633,7 +633,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
-        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
+        if inference_params is not None:
+            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
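This hunk narrows the shape assertion to generation-time calls. A compressed sketch of the new control flow; the 2-D shape for the sequence-parallel case is an illustration inferred from the assert message, not code from this file:

import torch

def lm_head_slice(hidden_states, inference_params=None, num_last_tokens=0):
    # Only enforce (batch, seqlen, hidden) when a generation pass is active;
    # a plain forward with sequence_parallel can hand back flattened activations.
    if inference_params is not None:
        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
    if num_last_tokens > 0:
        hidden_states = hidden_states[:, -num_last_tokens:]
    return hidden_states

lm_head_slice(torch.randn(16, 64))                               # training-style call: no shape check
lm_head_slice(torch.randn(2, 8, 64), inference_params=object())  # generation path: still checked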
@@ -607,7 +607,7 @@ class MHA(nn.Module):
             )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         rotary_max_seqlen = (
-            inference_params.max_sequene_len if inference_params is not None else None
+            inference_params.max_sequence_len if inference_params is not None else None
         )
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -859,7 +859,7 @@ class ParallelMHA(nn.Module):
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         rotary_max_seqlen = (
-            inference_params.max_sequene_len if inference_params is not None else None
+            inference_params.max_sequence_len if inference_params is not None else None
         )
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
@@ -29,19 +29,15 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache


-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_state_dict(model_name):
-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dict = remap_state_dict_hf_baichuan(ckpt_state_dicts[0], config)
-    model = GPTLMHeadModel(
-        config, device="meta"
-    )  # Without device='meta' init is very slow
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
+    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
     assert state_dict.keys() == pretrained_state_dict.keys()
@@ -49,20 +45,16 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
-
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -70,11 +62,9 @@ def test_baichuan_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -96,7 +86,7 @@ def test_baichuan_optimized(model_name):
     # Without device_map, the model is loaded on the CPU, which is very slow
     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True
+        model_name, device_map="auto", trust_remote_code=True
     )
     model_ref.eval()
     with torch.no_grad():
@@ -105,10 +95,7 @@ def test_baichuan_optimized(model_name):
     del model_ref

     model_hf = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name,
-        torch_dtype=dtype,
-        device_map={"": device},
-        trust_remote_code=True,
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True,
     )
     model_hf.eval()
     with torch.no_grad():
@@ -133,23 +120,19 @@ def test_baichuan_optimized(model_name):
     ).abs().max().item()


-# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel"
+# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
-def test_baichuan_parallel(model_name, world_size):
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
     from apex.transformer import parallel_state

-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
-
     dtype = torch.float16
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -165,11 +148,12 @@ def test_baichuan_parallel(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    # Need this, otherwise the Triton kernel seems to be launched from the wrong device.
+    torch.cuda.set_device(device)
+
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )

     model = GPTLMHeadModel(
         config, process_group=process_group, device=device, dtype=dtype
@@ -197,13 +181,12 @@ def test_baichuan_parallel(model_name, world_size):
     logits, _ = all_gather_raw(logits, process_group)
     logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
     del model
+    parallel_state.destroy_model_parallel()

     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_ref = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, device_map="auto", trust_remote_code=True
         )
         model_ref.eval()
         with torch.no_grad():
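In the parallel forward check above, each rank produces a vocabulary shard of the logits; all_gather_raw stacks the shards along the leading dimension, and the rearrange pattern folds that rank axis back into the vocabulary axis. A toy version of just the reshuffle, with shapes invented for the example:

import torch
from einops import rearrange

world_size, batch_size, seqlen, vocab_shard = 2, 3, 5, 7

# Stand-in for the gathered logits: world_size shards of shape
# (batch_size, seqlen, vocab_shard) stacked along the first dimension.
gathered = torch.randn(world_size * batch_size, seqlen, vocab_shard)

logits = rearrange(gathered, "(n b) ... d -> b ... (n d)", b=batch_size)
assert logits.shape == (batch_size, seqlen, world_size * vocab_shard)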
@@ -212,10 +195,7 @@ def test_baichuan_parallel(model_name, world_size):
         del model_ref

         model_hf = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            torch_dtype=dtype,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
         )
         model_hf.eval()
         with torch.no_grad():
@@ -240,16 +220,12 @@ def test_baichuan_parallel(model_name, world_size):
         ).abs().max().item()


-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_generation(model_name):
-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
-
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -257,9 +233,7 @@ def test_baichuan_generation(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

-    tokenizer = AutoTokenizer.from_pretrained(
-        Path(checkpoint_path) / model_name, trust_remote_code=True
-    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     eos_token_id = tokenizer.eos_token_id

     torch.manual_seed(0)
@@ -271,10 +245,7 @@ def test_baichuan_generation(model_name):
     )

     model_hf = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name,
-        torch_dtype=dtype,
-        device_map={"": device},
-        trust_remote_code=True,
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
     )
     model_hf.eval()
     print("HF fp16")
@@ -292,7 +263,7 @@ def test_baichuan_generation(model_name):

     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True
+        model_name, device_map="auto", trust_remote_code=True
     )
     model_ref.eval()
     with torch.no_grad():
@@ -301,11 +272,9 @@ def test_baichuan_generation(model_name):
         )
     del model_ref

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -368,7 +337,7 @@ def test_baichuan_generation(model_name):

 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "baichuan_parallel_generation"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_parallel_generation(model_name, world_size):
     """Check that our implementation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -376,13 +345,9 @@ def test_baichuan_parallel_generation(model_name, world_size):
     """
     from apex.transformer import parallel_state

-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
-
     dtype = torch.float16
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = False
     config.fused_bias_fc = True
@@ -413,11 +378,9 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )

     model = GPTLMHeadModel(
         config, process_group=process_group, device=device, dtype=dtype
@@ -464,10 +427,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_hf = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            torch_dtype=dtype,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
         )
         model_hf.eval()
         print("HF fp16")
@@ -487,9 +447,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         del model_hf

         model_ref = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, device_map="auto", trust_remote_code=True
         )
         model_ref.eval()
         with torch.inference_mode():
@@ -146,6 +146,7 @@ def test_falcon_parallel_forward(model_name, world_size):
     logits, _ = all_gather_raw(logits, process_group)
     logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
     del model
+    parallel_state.destroy_model_parallel()

     if rank == 0:
         model_hf = AutoModelForCausalLM.from_pretrained(
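The parallel_state.destroy_model_parallel() additions (here and in the Baichuan parallel forward test) tear the tensor-parallel groups down so the next parametrized test starts from a clean state. A schematic of the init/teardown pairing, assuming apex.transformer's usual setup call; the keyword name with the trailing underscore is recalled from apex and may differ across versions:

import torch
from apex.transformer import parallel_state

def parallel_test_body(world_size):
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    try:
        rank = parallel_state.get_tensor_model_parallel_rank()
        process_group = parallel_state.get_tensor_model_parallel_group()
        print(f"rank {rank}, tp group size {torch.distributed.get_world_size(process_group)}")
        ...  # build the sharded model, run the forward pass, compare against HF on rank 0
    finally:
        # Tear the groups down so a later test (or parametrization) can re-initialize them.
        parallel_state.destroy_model_parallel()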