Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
23eca9cf
Unverified
Commit
23eca9cf
authored
Feb 24, 2025
by
Mengqing Cao
Committed by
GitHub
Feb 24, 2025
Browse files
[model][refactor] remove cuda hard code in models and layers (#13658)
parent
437b76ff
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
29 additions
and
14 deletions
+29
-14
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+2
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+9
-4
vllm/model_executor/layers/spec_decode_base_sampler.py
vllm/model_executor/layers/spec_decode_base_sampler.py
+3
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+2
-1
vllm/model_executor/models/arctic.py
vllm/model_executor/models/arctic.py
+3
-2
vllm/model_executor/models/minicpm.py
vllm/model_executor/models/minicpm.py
+3
-2
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+7
-3
No files found.
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
23eca9cf
...
...
@@ -7,6 +7,7 @@ import torch
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -238,7 +239,7 @@ def fused_marlin_moe(
max_workspace_size
=
(
max
(
2
*
N
,
K
)
//
64
)
*
16
workspace
=
torch
.
zeros
(
max_workspace_size
,
dtype
=
torch
.
int
,
device
=
"
cu
da"
,
device
=
cu
rrent_platform
.
device_type
,
requires_grad
=
False
)
if
has_no_zp
:
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
23eca9cf
...
...
@@ -30,6 +30,7 @@ import torch.nn as nn
from
transformers
import
PretrainedConfig
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
def
_rotate_neox
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -650,8 +651,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
scaling_factor
:
float
)
->
torch
.
Tensor
:
pos_freqs
=
self
.
base
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
float
,
device
=
"cuda"
)
/
pos_freqs
=
self
.
base
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
float
,
device
=
current_platform
.
device_type
)
/
self
.
rotary_dim
)
inv_freq_extrapolation
=
1.0
/
pos_freqs
inv_freq_interpolation
=
1.0
/
(
scaling_factor
*
pos_freqs
)
...
...
@@ -670,7 +675,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
scaling_factor
)
t
=
torch
.
arange
(
self
.
max_position_embeddings
*
self
.
scaling_factor
,
device
=
"
cu
da"
,
device
=
cu
rrent_platform
.
device_type
,
dtype
=
torch
.
float32
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
(
freqs
.
cos
()
*
self
.
mscale
)
...
...
vllm/model_executor/layers/spec_decode_base_sampler.py
View file @
23eca9cf
...
...
@@ -7,6 +7,8 @@ import torch
import
torch.jit
import
torch.nn
as
nn
from
vllm.platforms
import
current_platform
class
SpecDecodeBaseSampler
(
nn
.
Module
):
"""Base class for samplers used for Speculative Decoding verification
...
...
@@ -35,7 +37,7 @@ class SpecDecodeBaseSampler(nn.Module):
def
init_gpu_tensors
(
self
,
device
:
Union
[
int
,
str
])
->
None
:
assert
self
.
num_accepted_tokens
is
None
if
isinstance
(
device
,
int
):
device
=
f
"cu
da
:
{
device
}
"
device
=
f
"
{
cu
rrent_platform
.
device_type
}
:
{
device
}
"
elif
not
isinstance
(
device
,
str
):
raise
ValueError
(
f
"Device must be int or str, get
{
type
(
device
)
}
"
)
self
.
num_accepted_tokens
=
torch
.
tensor
(
0
,
...
...
vllm/model_executor/model_loader/loader.py
View file @
23eca9cf
...
...
@@ -914,7 +914,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if
param_name
+
"."
in
k
:
quant_state
[
k
]
=
temp_state_dict
[
k
]
return
QuantState
.
from_dict
(
quant_state
,
device
=
"cuda"
)
return
QuantState
.
from_dict
(
quant_state
,
device
=
current_platform
.
device_type
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
...
...
vllm/model_executor/models/arctic.py
View file @
23eca9cf
...
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.arctic
import
ArcticConfig
...
...
@@ -138,13 +139,13 @@ class ArcticMoE(nn.Module):
torch
.
empty
(
self
.
num_experts
,
2
*
self
.
intermediate_size
,
self
.
hidden_size
,
device
=
"
cu
da"
,
device
=
cu
rrent_platform
.
device_type
,
dtype
=
self
.
params_dtype
))
self
.
w2s
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_experts
,
self
.
hidden_size
,
self
.
intermediate_size
,
device
=
"
cu
da"
,
device
=
cu
rrent_platform
.
device_type
,
dtype
=
self
.
params_dtype
))
set_weight_attrs
(
self
.
ws
,
{
"weight_loader"
:
self
.
weight_loader
,
...
...
vllm/model_executor/models/minicpm.py
View file @
23eca9cf
...
...
@@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsPP
...
...
@@ -98,13 +99,13 @@ class MiniCPMMoE(nn.Module):
torch
.
empty
(
self
.
num_total_experts
,
2
*
self
.
intermediate_size
,
self
.
hidden_size
,
device
=
"
cu
da"
,
device
=
cu
rrent_platform
.
device_type
,
dtype
=
self
.
params_dtype
))
self
.
w2s
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
self
.
hidden_size
,
self
.
intermediate_size
,
device
=
"
cu
da"
,
device
=
cu
rrent_platform
.
device_type
,
dtype
=
self
.
params_dtype
))
set_weight_attrs
(
self
.
ws
,
{
...
...
vllm/model_executor/models/minicpmv.py
View file @
23eca9cf
...
...
@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
.idefics2_vision_model
import
Idefics2VisionTransformer
...
...
@@ -1184,7 +1185,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config
=
quant_config
,
prefix
=
prefix
)
return
resampler
.
to
(
device
=
"cuda"
,
dtype
=
torch
.
get_default_dtype
())
return
resampler
.
to
(
device
=
current_platform
.
device_type
,
dtype
=
torch
.
get_default_dtype
())
def
get_vision_embedding
(
self
,
...
...
@@ -1266,7 +1268,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config
=
quant_config
,
prefix
=
prefix
)
return
resampler
.
to
(
device
=
"cuda"
,
dtype
=
torch
.
get_default_dtype
())
return
resampler
.
to
(
device
=
current_platform
.
device_type
,
dtype
=
torch
.
get_default_dtype
())
def
get_vision_embedding
(
self
,
...
...
@@ -1360,7 +1363,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
quant_config
=
quant_config
,
prefix
=
prefix
)
return
resampler
.
to
(
device
=
"cuda"
,
dtype
=
torch
.
get_default_dtype
())
return
resampler
.
to
(
device
=
current_platform
.
device_type
,
dtype
=
torch
.
get_default_dtype
())
def
get_vision_embedding
(
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment