Unverified Commit 5a176c92 authored by Yineng Zhang, committed by GitHub

fix deepseek v2 with cpu device (#2975)
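Summary of the change, as reflected in the diff below: `DeepseekScalingRotaryEmbedding` gains a `device` parameter (defaulting to `"cuda"`) and uses it in place of the previously hard-coded `"cuda"` when building its frequency and position tensors. A new `get_rope_cpu` constructs the embedding on CPU (only the `deepseek_yarn` scaling type is supported there for now), and `get_rope_wrapper` dispatches to `get_rope_cpu` when the target device is `"cpu"` and to the existing `get_rope` otherwise. `DeepseekV2Attention` now calls the wrapper, and a vision model's `divide`/`get_tensor_model_parallel_world_size` import moves from `vllm.distributed` to `sglang.srt.distributed`.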

parent 4719c1d0
@@ -664,6 +664,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         beta_slow: int = 1,
         mscale: float = 1,
         mscale_all_dim: float = 0,
+        device: Optional[str] = "cuda",
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -676,13 +677,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
             / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
             * attn_factor
         )
+        self.device = device
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
         pos_freqs = self.base ** (
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device="cuda")
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device)
             / self.rotary_dim
         )
         inv_freq_extrapolation = 1.0 / pos_freqs
@@ -710,7 +712,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(
             self.max_position_embeddings * self.scaling_factor,
-            device="cuda",
+            device=self.device,
             dtype=torch.float32,
         )
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
@@ -1174,3 +1176,111 @@ def get_rope(
         raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
     return rotary_emb
+
+
+def get_rope_cpu(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
+    partial_rotary_factor: float = 1.0,
+    device: Optional[str] = None,
+) -> RotaryEmbedding:
+    if dtype is None:
+        dtype = torch.get_default_dtype()
+    if rope_scaling is not None:
+        # Transforms every value that is a list into a tuple for caching calls
+        rope_scaling_tuple = {
+            k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
+        }
+        rope_scaling_args = tuple(rope_scaling_tuple.items())
+    else:
+        rope_scaling_args = None
+    if partial_rotary_factor < 1.0:
+        rotary_dim = int(rotary_dim * partial_rotary_factor)
+    key = (
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+        rope_scaling_args,
+        dtype,
+    )
+    if key in _ROPE_DICT:
+        return _ROPE_DICT[key]
+
+    assert rope_scaling is not None
+    scaling_type = rope_scaling["rope_type"]
+    assert (
+        scaling_type == "deepseek_yarn"
+    ), "Only deepseek_yarn is supported for CPU for now"
+
+    scaling_factor = rope_scaling["factor"]
+    original_max_position = rope_scaling["original_max_position_embeddings"]
+    extra_kwargs = {
+        k: v
+        for k, v in rope_scaling.items()
+        if k
+        in (
+            "extrapolation_factor",
+            "attn_factor",
+            "beta_fast",
+            "beta_slow",
+            "mscale",
+            "mscale_all_dim",
+        )
+    }
+    extra_kwargs["device"] = device
+    rotary_emb = DeepseekScalingRotaryEmbedding(
+        head_size,
+        rotary_dim,
+        original_max_position,
+        base,
+        is_neox_style,
+        scaling_factor,
+        dtype,
+        **extra_kwargs,
+    )
+
+    _ROPE_DICT[key] = rotary_emb
+    return rotary_emb
+
+
+def get_rope_wrapper(
+    head_size: int,
+    rotary_dim: int,
+    max_position: int,
+    base: int,
+    is_neox_style: bool = True,
+    rope_scaling: Optional[Dict[str, Any]] = None,
+    dtype: Optional[torch.dtype] = None,
+    partial_rotary_factor: float = 1.0,
+    device: Optional[str] = None,
+):
+    if device != "cpu":
+        return get_rope(
+            head_size,
+            rotary_dim,
+            max_position,
+            base,
+            is_neox_style,
+            rope_scaling,
+            dtype,
+            partial_rotary_factor,
+        )
+
+    return get_rope_cpu(
+        head_size,
+        rotary_dim,
+        max_position,
+        base,
+        is_neox_style,
+        rope_scaling,
+        dtype,
+        partial_rotary_factor,
+        device,
+    )
@@ -48,7 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -271,7 +271,7 @@ class DeepseekV2Attention(nn.Module):
                 quant_config=quant_config,
             )
             rope_scaling["rope_type"] = "deepseek_yarn"
-            self.rotary_emb = get_rope(
+            self.rotary_emb = get_rope_wrapper(
                 qk_rope_head_dim,
                 rotary_dim=qk_rope_head_dim,
                 max_position=max_position_embeddings,
...
@@ -39,12 +39,12 @@ from PIL import Image
 from torch import nn
 from torch.nn.init import trunc_normal_
 from transformers import PretrainedConfig
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata

+from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import (
...
File mode changed from 100755 to 100644
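For orientation, here is a minimal sketch of how the new dispatch is meant to be exercised. It is not part of the commit: the parameter values are illustrative, loosely modeled on DeepSeek-V2-style YaRN settings, and are assumptions rather than values taken from the diff.

```python
# Hedged sketch: exercising get_rope_wrapper's device dispatch.
# All numeric values below are illustrative assumptions.
import torch

from sglang.srt.layers.rotary_embedding import get_rope_wrapper

rope_scaling = {
    "rope_type": "deepseek_yarn",  # the only type get_rope_cpu accepts for now
    "factor": 40,
    "original_max_position_embeddings": 4096,
    "beta_fast": 32,
    "beta_slow": 1,
    "mscale": 1.0,
    "mscale_all_dim": 1.0,
}

# device="cpu" routes to get_rope_cpu, which builds the embedding's
# frequency tensors on CPU; any other device falls back to get_rope.
rotary_emb = get_rope_wrapper(
    head_size=64,
    rotary_dim=64,
    max_position=4096,
    base=10000,
    is_neox_style=False,
    rope_scaling=rope_scaling,
    dtype=torch.float32,
    device="cpu",
)
```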