Unverified Commit 9736cd3b authored by Swipe4057's avatar Swipe4057 Committed by GitHub
Browse files

[Bugfix] pipeline parallelism and Eagle Qwen2 (#6910)

parent 2f715f51
...@@ -24,13 +24,14 @@ from typing import Iterable, Optional, Tuple ...@@ -24,13 +24,14 @@ from typing import Iterable, Optional, Tuple
import torch import torch
from torch import nn from torch import nn
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
Qwen2Config = None Qwen2Config = None
...@@ -87,6 +88,7 @@ class Qwen2Model(nn.Module): ...@@ -87,6 +88,7 @@ class Qwen2Model(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
...@@ -119,6 +121,7 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM): ...@@ -119,6 +121,7 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.pp_group = get_pp_group()
self.model = Qwen2Model( self.model = Qwen2Model(
config, quant_config=quant_config, prefix=add_prefix("model", prefix) config, quant_config=quant_config, prefix=add_prefix("model", prefix)
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment