Commit 31bcf867 authored by Jianghai, committed by Hongxin Liu
Browse files

[pipeline] Llama causal lm and llama for sequence classification pipeline (#4208)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a22568dbeed6d4563372b25e1e825fb0.

This policy should be reverted and copied to feature/bloom

* revert the bloom changes

* cancel unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision
parent 16220310
...@@ -162,6 +162,24 @@ class Policy(ABC): ...@@ -162,6 +162,24 @@ class Policy(ABC):
return policy return policy
def append_or_create_method_replacement(
        self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
        target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
    r"""
    Append or create a new method replacement description to the policy for the given key.

    Args:
        description (Dict[str, Callable]): mapping from method name to the callable that replaces it
        policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated (modified in place)
        target_key (Union[str, nn.Module]): the key of the policy to be updated

    Returns:
        Dict[Union[str, nn.Module], ModulePolicyDescription]: the same ``policy`` object that was passed in
    """
    if target_key in policy:
        entry = policy[target_key]
        if entry.method_replacement is None:
            # NOTE(review): assumes an existing entry may have been created without
            # any method replacement (method_replacement left as None) -- confirm
            # against ModulePolicyDescription. In that case there is no dict to
            # merge into, so the description becomes the entry's replacement map.
            entry.method_replacement = description
        else:
            # Merge into the existing map; same-named methods are overwritten.
            entry.method_replacement.update(description)
    else:
        policy[target_key] = ModulePolicyDescription(method_replacement=description)
    return policy
def get_held_layers(self) -> List[Module]: def get_held_layers(self) -> List[Module]:
"""Get layers that should be held in current stage. This method should be implemented by subclass. """Get layers that should be held in current stage. This method should be implemented by subclass.
......
...@@ -131,17 +131,20 @@ class LlamaModelPolicy(LlamaPolicy): ...@@ -131,17 +131,20 @@ class LlamaModelPolicy(LlamaPolicy):
super().__init__() super().__init__()
def module_policy(self): def module_policy(self):
module_policy = super().module_policy() policy = super().module_policy()
from transformers.models.llama.modeling_llama import LlamaModel from transformers.models.llama.modeling_llama import LlamaModel
if self.pipeline_stage_manager: if self.pipeline_stage_manager:
# set None as default # set None as default
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages) layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
module_policy[LlamaModel] = ModulePolicyDescription(method_replacement={ method_replacement = {
'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index) 'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index)
}) }
return module_policy self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaModel)
return policy
def get_held_layers(self) -> List[Module]: def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage.""" """Get pipeline layers for current stage."""
...@@ -158,7 +161,7 @@ class LlamaModelPolicy(LlamaPolicy): ...@@ -158,7 +161,7 @@ class LlamaModelPolicy(LlamaPolicy):
return held_layers return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]: def get_shared_params(self) -> List[Dict[int, Tensor]]:
"""No shared params in bert model""" """No shared params in llama model"""
return [] return []
...@@ -179,8 +182,43 @@ class LlamaForCausalLMPolicy(LlamaPolicy): ...@@ -179,8 +182,43 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
]) ])
} }
policy.update(new_item) policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
stage_manager = self.pipeline_stage_manager
layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward': partial(llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaForCausalLM)
return policy return policy
def get_held_layers(self) -> List[Module]:
    """Collect the layers of the causal-LM model that the current pipeline stage holds.

    The first stage holds the token embedding, every stage holds its own slice
    of the decoder layers, and the last stage additionally holds the final
    norm and the LM head.
    """
    manager = self.pipeline_stage_manager
    causal_lm = self.model
    decoder = causal_lm.model

    per_stage = self.distribute_layers(len(decoder.layers), manager.num_stages)

    # Embedding only lives on the first stage.
    layers = [decoder.embed_tokens] if manager.is_first_stage() else []

    # Decoder layers assigned to this stage.
    begin, end = self.get_stage_index(per_stage, manager.stage)
    layers += decoder.layers[begin:end]

    # Final norm and output head only live on the last stage.
    if manager.is_last_stage():
        layers += [decoder.norm, causal_lm.lm_head]
    return layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
    """Return the parameters shared across pipeline stages.

    If the input embedding weight and the LM head weight are the same object
    (tied weights), returns a single group mapping stage 0 to the embedding
    weight and the last stage index to the LM head weight. Otherwise returns
    an empty list.
    """
    llama_model = self.model.model
    # Use the same manager attribute as the other methods of this policy.
    stage_manager = self.pipeline_stage_manager
    if llama_model.embed_tokens.weight is self.model.lm_head.weight:
        # tie weights
        return [{0: llama_model.embed_tokens.weight, stage_manager.num_stages - 1: self.model.lm_head.weight}]
    return []
class LlamaForSequenceClassificationPolicy(LlamaPolicy): class LlamaForSequenceClassificationPolicy(LlamaPolicy):
...@@ -199,8 +237,42 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy): ...@@ -199,8 +237,42 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
]) ])
} }
policy.update(new_item) policy.update(new_item)
# to be confirmed
if self.pipeline_stage_manager:
# set None as default
stage_manager = self.pipeline_stage_manager
layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {
'forward':
partial(llama_for_sequence_classification_forward,
stage_manager=stage_manager,
stage_index=stage_index)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=LlamaForSequenceClassification)
return policy return policy
def get_held_layers(self) -> List[Module]:
    """Collect the layers of the sequence-classification model held by the current stage.

    The result is the token embedding (first stage only), followed by this
    stage's slice of decoder layers, followed by the final norm and the
    ``score`` head (last stage only).
    """
    stage_manager = self.pipeline_stage_manager
    classifier = self.model
    backbone = classifier.model
    layers_per_stage = self.distribute_layers(len(backbone.layers), stage_manager.num_stages)

    leading = [backbone.embed_tokens] if stage_manager.is_first_stage() else []
    start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
    middle = list(backbone.layers[start_idx:end_idx])
    trailing = [backbone.norm, classifier.score] if stage_manager.is_last_stage() else []
    return leading + middle + trailing
def get_shared_params(self) -> List[Dict[int, Tensor]]:
    """Return the shared-parameter groups for this policy, which is always an empty list."""
    no_shared_params: List[Dict[int, Tensor]] = []
    return no_shared_params
def llama_model_forward( def llama_model_forward(
self: LlamaModel, self: LlamaModel,
......
...@@ -52,7 +52,7 @@ loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean() ...@@ -52,7 +52,7 @@ loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
loss_fn = lambda x: x.loss loss_fn = lambda x: x.loss
config = transformers.GPT2Config(n_layer=2, config = transformers.GPT2Config(n_layer=2,
n_head=2, n_head=4,
vocab_size=50258, vocab_size=50258,
attn_pdrop=0, attn_pdrop=0,
embd_pdrop=0, embd_pdrop=0,
......
...@@ -49,21 +49,19 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la ...@@ -49,21 +49,19 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
x = torch.randint(0, 1000, (2, 3)).cuda() x = torch.randint(0, 1000, (2, 3)).cuda()
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda() hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name == 'transformers_llama': org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init)
enable_tensor_parallelism, use_lazy_init) if stage_manager.stage == 0:
if stage_manager.stage == 0: attention_mask = torch.ones_like(x).cuda()
attention_mask = torch.ones_like(x).cuda() output = sharded_model(input_ids=x, attention_mask=attention_mask)
output = sharded_model(input_ids=x, attention_mask=attention_mask) assert output['hidden_states'].shape == (2, 3, 128)
assert output['hidden_states'].shape == (2, 3, 128) else:
else: attention_mask = torch.ones((2, 3)).cuda()
attention_mask = torch.ones((2, 3)).cuda() output = sharded_model(
output = sharded_model( hidden_states=hidden_states,
hidden_states=hidden_states, attention_mask=attention_mask,
attention_mask=attention_mask, )
) assert output[0] is not None
# print(output[0].shape)
assert output[0].shape == (2, 3, 128)
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment