Commit 798858f9 authored by Tri Dao

Fix test_baichuan

parent 7b33743a
@@ -633,7 +633,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         hidden_states = self.transformer(
             input_ids, position_ids=position_ids, inference_params=inference_params
         )
-        assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
+        if inference_params is not None:
+            assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
         if num_last_tokens > 0:
             hidden_states = hidden_states[:, -num_last_tokens:]
         if self.project_out is not None:
...
@@ -607,7 +607,7 @@ class MHA(nn.Module):
         )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         rotary_max_seqlen = (
-            inference_params.max_sequene_len if inference_params is not None else None
+            inference_params.max_sequence_len if inference_params is not None else None
         )
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -859,7 +859,7 @@ class ParallelMHA(nn.Module):
         qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         rotary_max_seqlen = (
-            inference_params.max_sequene_len if inference_params is not None else None
+            inference_params.max_sequence_len if inference_params is not None else None
         )
         if self.num_heads_kv == self.num_heads:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
...
@@ -29,19 +29,15 @@ from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache


-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_state_dict(model_name):
-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dict = remap_state_dict_hf_baichuan(ckpt_state_dicts[0], config)
-    model = GPTLMHeadModel(
-        config, device="meta"
-    )  # Without device='meta' init is very slow
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
+    model = GPTLMHeadModel(config, device="meta")  # Without device='meta' init is very slow
     state_dict = model.state_dict()
     assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
     assert state_dict.keys() == pretrained_state_dict.keys()
@@ -49,20 +45,16 @@ def test_baichuan_state_dict(model_name):
         assert state_dict[k].shape == pretrained_state_dict[k].shape


-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_optimized(model_name):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
-
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -70,11 +62,9 @@ def test_baichuan_optimized(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -96,7 +86,7 @@ def test_baichuan_optimized(model_name):
     # Without device_map, the model is loaded on the CPU, which is very slow
     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True
+        model_name, device_map="auto", trust_remote_code=True
     )
     model_ref.eval()
     with torch.no_grad():
@@ -105,10 +95,7 @@ def test_baichuan_optimized(model_name):
     del model_ref

     model_hf = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name,
-        torch_dtype=dtype,
-        device_map={"": device},
-        trust_remote_code=True,
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True,
     )
     model_hf.eval()
     with torch.no_grad():
@@ -133,23 +120,19 @@ def test_baichuan_optimized(model_name):
     ).abs().max().item()


-# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel"
+# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
-def test_baichuan_parallel(model_name, world_size):
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
+def test_baichuan_parallel_forward(model_name, world_size):
     """Check that our implementation of Baichuan (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
     from apex.transformer import parallel_state

-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
-
     dtype = torch.float16
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -165,11 +148,12 @@ def test_baichuan_parallel(model_name, world_size):
     rank = parallel_state.get_tensor_model_parallel_rank()
     process_group = parallel_state.get_tensor_model_parallel_group()

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    # Need this, otherwise the Triton kernel seems to launched from the wrong device.
+    torch.cuda.set_device(device)
+
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )

     model = GPTLMHeadModel(
         config, process_group=process_group, device=device, dtype=dtype
@@ -197,13 +181,12 @@ def test_baichuan_parallel(model_name, world_size):
     logits, _ = all_gather_raw(logits, process_group)
     logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
     del model
+    parallel_state.destroy_model_parallel()

     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_ref = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, device_map="auto", trust_remote_code=True
         )
         model_ref.eval()
         with torch.no_grad():
@@ -212,10 +195,7 @@ def test_baichuan_parallel(model_name, world_size):
         del model_ref

         model_hf = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            torch_dtype=dtype,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
         )
         model_hf.eval()
         with torch.no_grad():
@@ -240,16 +220,12 @@ def test_baichuan_parallel(model_name, world_size):
         ).abs().max().item()


-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_generation(model_name):
-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
-
     dtype = torch.float16
     device = "cuda"
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -257,9 +233,7 @@ def test_baichuan_generation(model_name):
     config.fused_dropout_add_ln = True
     config.residual_in_fp32 = True

-    tokenizer = AutoTokenizer.from_pretrained(
-        Path(checkpoint_path) / model_name, trust_remote_code=True
-    )
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     eos_token_id = tokenizer.eos_token_id

     torch.manual_seed(0)
@@ -271,10 +245,7 @@ def test_baichuan_generation(model_name):
     )

     model_hf = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name,
-        torch_dtype=dtype,
-        device_map={"": device},
-        trust_remote_code=True,
+        model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
     )
     model_hf.eval()
     print("HF fp16")
@@ -292,7 +263,7 @@ def test_baichuan_generation(model_name):

     # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
     model_ref = AutoModelForCausalLM.from_pretrained(
-        Path(checkpoint_path) / model_name, device_map="auto", trust_remote_code=True
+        model_name, device_map="auto", trust_remote_code=True
     )
     model_ref.eval()
     with torch.no_grad():
@@ -301,11 +272,9 @@ def test_baichuan_generation(model_name):
         )
     del model_ref

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
     model.load_state_dict(pretrained_state_dict)
     model.eval()
@@ -368,7 +337,7 @@ def test_baichuan_generation(model_name):

 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "baichuan_parallel_generation"
 @pytest.mark.parametrize("world_size", [2])
-@pytest.mark.parametrize("model_name", ["Baichuan-7B"])
+@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
 def test_baichuan_parallel_generation(model_name, world_size):
     """Check that our implementation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -376,13 +345,9 @@ def test_baichuan_parallel_generation(model_name, world_size):
     """
     from apex.transformer import parallel_state

-    checkpoint_path = Path(
-        os.environ.get("CHECKPOINT_DIR", current_dir.parent.parent / "checkpoints")
-    )
-
     dtype = torch.float16
     config = baichuan_config_to_gpt2_config(
-        config_from_checkpoint(checkpoint_path, model_name)
+        AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     )
     config.use_flash_attn = False
     config.fused_bias_fc = True
@@ -413,11 +378,9 @@ def test_baichuan_parallel_generation(model_name, world_size):
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)

-    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
-    pretrained_state_dicts = [
-        remap_state_dict_hf_baichuan(s, config) for s in ckpt_state_dicts
-    ]
-    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
+    pretrained_state_dict = remap_state_dict_hf_baichuan(
+        state_dict_from_pretrained(model_name), config
+    )

     model = GPTLMHeadModel(
         config, process_group=process_group, device=device, dtype=dtype
@@ -464,10 +427,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
     if rank == 0:
         # Without device_map, the model is loaded on the CPU, which is very slow
         model_hf = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            torch_dtype=dtype,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
         )
         model_hf.eval()
         print("HF fp16")
@@ -487,9 +447,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         del model_hf

         model_ref = AutoModelForCausalLM.from_pretrained(
-            Path(checkpoint_path) / model_name,
-            device_map="auto",
-            trust_remote_code=True,
+            model_name, device_map="auto", trust_remote_code=True
         )
         model_ref.eval()
         with torch.inference_mode():
...
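The forward-comparison docstrings above all state the same acceptance criterion: the fp16 logits from this implementation should be roughly as close to the HF fp32 reference as HF's own fp16 logits are. A minimal sketch of that check follows; the variable names and the tolerance factor of 3 are illustrative assumptions, not copied from the test file.

import torch

def assert_logits_close(logits: torch.Tensor, logits_hf: torch.Tensor, logits_ref: torch.Tensor) -> None:
    # Max absolute error of our fp16 output vs. the HF fp32 reference.
    our_err = (logits - logits_ref).abs().max().item()
    # Max absolute error of the HF fp16 output vs. the same fp32 reference.
    hf_err = (logits_hf - logits_ref).abs().max().item()
    print(f"our fp16 max err: {our_err:.6f}, HF fp16 max err: {hf_err:.6f}")
    # Accept if our error stays within a small multiple of HF's own fp16 error
    # (the factor is an assumption for illustration).
    assert our_err < 3 * hf_err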
@@ -146,6 +146,7 @@ def test_falcon_parallel_forward(model_name, world_size):
     logits, _ = all_gather_raw(logits, process_group)
     logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
     del model
+    parallel_state.destroy_model_parallel()

     if rank == 0:
         model_hf = AutoModelForCausalLM.from_pretrained(
...
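For reference, the checkpoint-loading path the updated tests exercise (pulling the config and weights directly from the Hugging Face Hub instead of a local CHECKPOINT_DIR) boils down to the sketch below. The flash_attn.models.baichuan import location is an assumption based on the helper names used above; everything else mirrors the diff.

import torch
from transformers import AutoConfig

from flash_attn.models.gpt import GPTLMHeadModel
# Assumed import location for the Baichuan helpers referenced in the test.
from flash_attn.models.baichuan import (
    baichuan_config_to_gpt2_config,
    remap_state_dict_hf_baichuan,
)
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "baichuan-inc/Baichuan-7B"
# Build the GPT2-style config straight from the Hub config; trust_remote_code
# is needed because Baichuan ships custom modeling code.
config = baichuan_config_to_gpt2_config(
    AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
# Download the HF weights and remap the keys to the flash-attn GPT layout,
# then load them into GPTLMHeadModel.
pretrained_state_dict = remap_state_dict_hf_baichuan(
    state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
model.load_state_dict(pretrained_state_dict)
model.eval()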