Unverified Commit e6707a6e authored by github-actions[bot]'s avatar github-actions[bot] Committed by GitHub
Browse files

[format] applied code formatting on changed files in pull request 5510 (#5517)


Co-authored-by: default avatargithub-actions <github-actions@github.com>
parent 19e1a5cf
......@@ -1302,7 +1302,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
)
if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
......
......@@ -15,10 +15,8 @@ from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d
try:
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
......
......@@ -291,13 +291,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output}
suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": not self.shard_config.parallel_output},
)
],
)
}
if self.shard_config.parallel_output:
addon_module[GPT2LMHeadModel].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
addon_module[GPT2LMHeadModel].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
......
......@@ -265,12 +265,18 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output})
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs={"gather_output": not self.shard_config.parallel_output},
)
],
)
}
if self.shard_config.parallel_output:
new_item[LlamaForCausalLM].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
policy.update(new_item)
if self.pipeline_stage_manager:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment