"vscode:/vscode.git/clone" did not exist on "751c3a037cdfa27e58cec5e316b3f23cb0b80db2"
Commit 9736cd3b (unverified), authored by Swipe4057, committed by GitHub
Browse files

[Bugfix] pipeline parallelism and Eagle Qwen2 (#6910)

parent 2f715f51
......@@ -24,13 +24,14 @@ from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
Qwen2Config = None
......@@ -87,6 +88,7 @@ class Qwen2Model(nn.Module):
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
......@@ -119,6 +121,7 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
nn.Module.__init__(self)
self.config = config
self.quant_config = quant_config
self.pp_group = get_pp_group()
self.model = Qwen2Model(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment