Unverified Commit 0a25e16e authored by flybird11111, committed by GitHub

[shardformer]gather llama logits (#5398)

* gather llama logits

* fix
parent dcdd8a5e
...@@ -16,6 +16,7 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
from ..layer._operation import _gather
try:
    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
...@@ -288,6 +289,9 @@ class LlamaPipelineForwards:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
if not shard_config.parallel_output:
    logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict:
    output = (logits,) + outputs[1:]
    return (loss,) + output if loss is not None else output
...@@ -588,6 +592,9 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels)
if not shard_config.parallel_output:
    logits = _gather(logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict:
    output = (logits,) + outputs[1:]
    return (loss,) + output if loss is not None else output
...
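Note on the added branch: when `parallel_output` is disabled, `_gather` collects the vocab-sharded logits from every rank of the tensor-parallel group along the last dimension, so the caller receives full-vocabulary logits instead of a local shard. A minimal sketch of that behavior, assuming a plain `torch.distributed` all-gather (the helper name `gather_last_dim` and the setup here are illustrative, not ColossalAI's actual `_gather` implementation):

```python
import torch
import torch.distributed as dist


def gather_last_dim(local_logits: torch.Tensor, process_group=None) -> torch.Tensor:
    """All-gather a tensor that is sharded along its last (vocab) dimension."""
    world_size = dist.get_world_size(group=process_group)
    if world_size == 1:
        return local_logits
    # Each rank holds logits of shape [batch, seq, vocab_size / world_size].
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits.contiguous(), group=process_group)
    # Concatenating along dim=-1 recovers the full vocabulary dimension.
    return torch.cat(shards, dim=-1)
```

When `parallel_output` stays `True` (the new default), the logits remain sharded and the distributed loss path (`cross_entropy_1d`, imported above) consumes them directly, so the extra all-gather is skipped.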
...@@ -34,6 +34,7 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
parallel_output = True
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# pipeline_parallel_size: int
# data_parallel_size: int
...
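The new `parallel_output` switch defaults to `True`, so the existing behavior (sharded logits plus distributed cross-entropy) is unchanged. Because the attribute is declared without a type annotation, the dataclass treats it as a class attribute rather than a generated `__init__` field, so it is toggled by assignment. A hedged usage sketch (the constructor argument shown is an assumption for illustration only):

```python
from colossalai.shardformer.shard import ShardConfig

# Build a config; enable_tensor_parallelism=False is assumed here only so the
# example does not require an initialized process group.
shard_config = ShardConfig(enable_tensor_parallelism=False)

# parallel_output has no annotation, so it is not an __init__ parameter;
# set it by assignment to request full-vocabulary logits from the forward pass.
shard_config.parallel_output = False
```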