Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5e766c85
Commit
5e766c85
authored
Mar 28, 2025
by
zhuwenwen
Browse files
[Kernel] Prototype integration of flux kernels
parent
feeb058b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
233 additions
and
20 deletions
+233
-20
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+6
-1
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+167
-12
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+41
-5
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+19
-2
No files found.
vllm/distributed/parallel_state.py
View file @
5e766c85
...
@@ -33,6 +33,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
...
@@ -33,6 +33,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union
)
Union
)
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
flux
import
torch
import
torch
import
torch.distributed
import
torch.distributed
from
torch.distributed
import
Backend
,
ProcessGroup
from
torch.distributed
import
Backend
,
ProcessGroup
...
@@ -207,6 +208,10 @@ class GroupCoordinator:
...
@@ -207,6 +208,10 @@ class GroupCoordinator:
self
.
use_hpu_communicator
=
use_hpu_communicator
self
.
use_hpu_communicator
=
use_hpu_communicator
self
.
use_xpu_communicator
=
use_xpu_communicator
self
.
use_xpu_communicator
=
use_xpu_communicator
# Initialize pynvshmem
if
torch
.
distributed
.
get_world_size
(
self
.
device_group
)
>
1
:
flux
.
init_flux_shm
(
self
.
device_group
)
# lazy import to avoid documentation build error
# lazy import to avoid documentation build error
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
CustomAllreduce
)
CustomAllreduce
)
...
...
vllm/model_executor/layers/linear.py
View file @
5e766c85
...
@@ -2,8 +2,9 @@
...
@@ -2,8 +2,9 @@
import
itertools
import
itertools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Optional
from
typing
import
Optional
,
List
import
flux
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
...
@@ -13,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -13,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
...
@@ -161,6 +163,144 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -161,6 +163,144 @@ class UnquantizedLinearMethod(LinearMethodBase):
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
class
GemmRS
(
LinearMethodBase
):
#Fused Gemm-ReduceScatter without quantization.
def
__init__
(
self
,
separate_bias_add
:
bool
=
False
):
self
.
separate_bias_add
=
separate_bias_add
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
if
self
.
use_llama_nn
:
self
.
gemm_rs_op
=
flux
.
GemmRS
(
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
max_m
=
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
output_size
,
# N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
True
,
# Note: bfloat16 requires fuse_reduction=False.
fuse_reduction
=
False
,
)
else
:
self
.
gemm_rs_op
=
flux
.
GemmRS
(
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
max_m
=
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
output_size
,
# N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
False
,
# Note: bfloat16 requires fuse_reduction=False.
fuse_reduction
=
False
,
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
bias
is
None
output
=
self
.
gemm_rs_op
.
forward
(
x
,
layer
.
weight
)
output
=
output
.
squeeze
(
0
)
return
output
class
AGCook
(
LinearMethodBase
):
#Fused AllGather-Gemm without quantization.
def
__init__
(
self
,
separate_bias_add
:
bool
=
False
):
self
.
separate_bias_add
=
separate_bias_add
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
if
self
.
use_llama_nn
:
self
.
ag_gemm_op
=
flux
.
AGKernel
(
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
full_m
=
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
weight
.
shape
[
0
],
# N
k_dim
=
weight
.
shape
[
1
],
# K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
output_dtype
=
params_dtype
,
#torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
True
,
# Note: if local_copy=True, I hit the following runtime error:
# /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
# Check failed: 33554432((input.numel() * input.element_size()))
# == 139836453421056((this->chunk_size))
local_copy
=
False
,
)
else
:
self
.
ag_gemm_op
=
flux
.
AGKernel
(
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
full_m
=
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
weight
.
shape
[
0
],
# N
k_dim
=
weight
.
shape
[
1
],
# K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
output_dtype
=
params_dtype
,
#torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
False
,
# Note: if local_copy=True, I hit the following runtime error:
# /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
# Check failed: 33554432((input.numel() * input.element_size()))
# == 139836453421056((this->chunk_size))
local_copy
=
False
,
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
assert
bias
is
None
output
=
self
.
ag_gemm_op
.
forward
(
x
,
layer
.
weight
)
return
output
class
LinearBase
(
torch
.
nn
.
Module
):
class
LinearBase
(
torch
.
nn
.
Module
):
"""Base linear layer.
"""Base linear layer.
...
@@ -181,6 +321,8 @@ class LinearBase(torch.nn.Module):
...
@@ -181,6 +321,8 @@ class LinearBase(torch.nn.Module):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
fuse_gemm_rs
:
bool
=
False
,
fuse_ag_gemm
:
bool
=
False
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -191,9 +333,14 @@ class LinearBase(torch.nn.Module):
...
@@ -191,9 +333,14 @@ class LinearBase(torch.nn.Module):
if
params_dtype
is
None
:
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
self
.
params_dtype
=
params_dtype
if
quant_config
is
None
:
if
fuse_gemm_rs
:
self
.
quant_method
:
Optional
[
assert
(
quant_config
is
None
)
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
GemmRS
()
elif
fuse_ag_gemm
:
assert
(
quant_config
is
None
)
self
.
quant_method
=
AGCook
()
elif
quant_config
is
None
:
self
.
quant_method
=
UnquantizedLinearMethod
()
else
:
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
prefix
=
prefix
)
...
@@ -308,9 +455,10 @@ class ColumnParallelLinear(LinearBase):
...
@@ -308,9 +455,10 @@ class ColumnParallelLinear(LinearBase):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
fuse_ag_gemm
:
bool
=
False
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
)
quant_config
,
prefix
,
fuse_ag_gemm
=
fuse_ag_gemm
)
self
.
gather_output
=
gather_output
self
.
gather_output
=
gather_output
...
@@ -447,7 +595,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -447,7 +595,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
fuse_ag_gemm
:
bool
=
False
):
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
...
@@ -458,7 +607,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -458,7 +607,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
)
prefix
=
prefix
,
fuse_ag_gemm
=
fuse_ag_gemm
)
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -723,7 +873,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -723,7 +873,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
fuse_ag_gemm
:
bool
=
False
):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
total_num_heads
=
total_num_heads
self
.
total_num_heads
=
total_num_heads
...
@@ -756,7 +907,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -756,7 +907,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
)
prefix
=
prefix
,
fuse_ag_gemm
=
fuse_ag_gemm
)
def
_get_shard_offset_mapping
(
self
,
loaded_shard_id
:
str
):
def
_get_shard_offset_mapping
(
self
,
loaded_shard_id
:
str
):
shard_offset_mapping
=
{
shard_offset_mapping
=
{
...
@@ -1060,12 +1212,15 @@ class RowParallelLinear(LinearBase):
...
@@ -1060,12 +1212,15 @@ class RowParallelLinear(LinearBase):
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
prefix
:
str
=
""
,
fuse_gemm_rs
:
bool
=
False
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
)
quant_config
,
prefix
,
fuse_gemm_rs
=
fuse_gemm_rs
)
self
.
input_is_parallel
=
input_is_parallel
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
if
fuse_gemm_rs
:
self
.
reduce_results
=
False
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
...
...
vllm/model_executor/models/llama.py
View file @
5e766c85
...
@@ -33,7 +33,7 @@ import vllm.envs as envs
...
@@ -33,7 +33,7 @@ import vllm.envs as envs
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.config
import
CacheConfig
,
VllmConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_rank
,
tensor_model_parallel_all_gather
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
@@ -71,6 +71,7 @@ class LlamaMLP(nn.Module):
...
@@ -71,6 +71,7 @@ class LlamaMLP(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
last_layer
:
bool
=
False
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
...
@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
...
@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
prefix
=
f
"
{
prefix
}
.gate_up_proj"
,
fuse_ag_gemm
=
True
,
)
)
self
.
down_proj
=
RowParallelLinear
(
self
.
down_proj
=
RowParallelLinear
(
input_size
=
intermediate_size
,
input_size
=
intermediate_size
,
...
@@ -86,6 +88,7 @@ class LlamaMLP(nn.Module):
...
@@ -86,6 +88,7 @@ class LlamaMLP(nn.Module):
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
prefix
=
f
"
{
prefix
}
.down_proj"
,
fuse_gemm_rs
=
(
not
last_layer
),
)
)
if
hidden_act
!=
"silu"
:
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
...
@@ -106,6 +109,7 @@ class LlamaAttention(nn.Module):
...
@@ -106,6 +109,7 @@ class LlamaAttention(nn.Module):
hidden_size
:
int
,
hidden_size
:
int
,
num_heads
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
num_kv_heads
:
int
,
first_layer
:
bool
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
...
@@ -148,6 +152,7 @@ class LlamaAttention(nn.Module):
...
@@ -148,6 +152,7 @@ class LlamaAttention(nn.Module):
bias
=
bias
,
bias
=
bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
prefix
=
f
"
{
prefix
}
.qkv_proj"
,
fuse_ag_gemm
=
(
not
first_layer
),
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
...
@@ -156,6 +161,7 @@ class LlamaAttention(nn.Module):
...
@@ -156,6 +161,7 @@ class LlamaAttention(nn.Module):
bias
=
bias_o_proj
,
bias
=
bias_o_proj
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
fuse_gemm_rs
=
True
,
)
)
is_neox_style
=
True
is_neox_style
=
True
...
@@ -223,6 +229,11 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -223,6 +229,11 @@ class LlamaDecoderLayer(nn.Module):
def
__init__
(
def
__init__
(
self
,
self
,
config
:
LlamaConfig
,
config
:
LlamaConfig
,
# Hack: pass in whether this is the first/last layer
# so we know if we can rewrite AllReduce -> ReduceScatter + AllGather,
# and then propagate the AllGather to the next layer.
first_layer
:
bool
,
last_layer
:
bool
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
prefix
:
str
=
""
,
...
@@ -252,6 +263,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -252,6 +263,7 @@ class LlamaDecoderLayer(nn.Module):
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
getattr
(
config
,
"num_key_value_heads"
,
num_kv_heads
=
getattr
(
config
,
"num_key_value_heads"
,
config
.
num_attention_heads
),
config
.
num_attention_heads
),
first_layer
=
first_layer
,
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
...
@@ -268,12 +280,16 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -268,12 +280,16 @@ class LlamaDecoderLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
bias
=
getattr
(
config
,
"mlp_bias"
,
False
),
bias
=
getattr
(
config
,
"mlp_bias"
,
False
),
prefix
=
f
"
{
prefix
}
.mlp"
,
prefix
=
f
"
{
prefix
}
.mlp"
,
last_layer
=
last_layer
,
)
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
eps
=
config
.
rms_norm_eps
)
self
.
first_layer
=
first_layer
self
.
last_layer
=
last_layer
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -287,17 +303,35 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -287,17 +303,35 @@ class LlamaDecoderLayer(nn.Module):
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
else
:
assert
(
hidden_states
.
shape
==
residual
.
shape
)
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
,
residual
)
# Partition residual
if
self
.
first_layer
:
n_slices
=
get_tensor_model_parallel_world_size
()
residual_slices
=
torch
.
chunk
(
residual
,
n_slices
,
dim
=
0
)
my_residual
=
residual_slices
[
get_tensor_model_parallel_rank
()]
else
:
my_residual
=
residual
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
)
attn_metadata
=
attn_metadata
)
# Fully Connected
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
assert
(
hidden_states
.
shape
==
my_residual
.
shape
)
hidden_states
,
residual
)
hidden_states
,
my_residual
=
self
.
post_attention_layernorm
(
hidden_states
,
my_residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
if
self
.
last_layer
:
residual
=
tensor_model_parallel_all_gather
(
my_residual
,
0
)
else
:
residual
=
my_residual
assert
(
hidden_states
.
shape
==
residual
.
shape
)
return
hidden_states
,
residual
return
hidden_states
,
residual
...
@@ -335,7 +369,9 @@ class LlamaModel(nn.Module):
...
@@ -335,7 +369,9 @@ class LlamaModel(nn.Module):
self
.
embed_tokens
=
PPMissingLayer
()
self
.
embed_tokens
=
PPMissingLayer
()
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
layer_type
(
config
=
config
,
lambda
prefix
,
first_layer
,
last_layer
:
layer_type
(
config
=
config
,
first_layer
=
first_layer
,
last_layer
=
last_layer
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
prefix
),
...
...
vllm/model_executor/models/utils.py
View file @
5e766c85
...
@@ -548,14 +548,31 @@ def make_layers(
...
@@ -548,14 +548,31 @@ def make_layers(
"""Make a list of layers with the given layer function, taking
"""Make a list of layers with the given layer function, taking
pipeline parallelism into account.
pipeline parallelism into account.
"""
"""
import
inspect
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.distributed.utils
import
get_pp_indices
start_layer
,
end_layer
=
get_pp_indices
(
num_hidden_layers
,
start_layer
,
end_layer
=
get_pp_indices
(
num_hidden_layers
,
get_pp_group
().
rank_in_group
,
get_pp_group
().
rank_in_group
,
get_pp_group
().
world_size
)
get_pp_group
().
world_size
)
# Determine if layer_fn accepts first/last args by inspecting its signature
sig
=
inspect
.
signature
(
layer_fn
)
has_firstlast_args
=
(
'first_layer'
in
sig
.
parameters
)
and
(
'last_layer'
in
sig
.
parameters
)
def
make_one_layer
(
idx
,
start_layer
,
end_layer
):
if
has_firstlast_args
:
return
maybe_offload_to_cpu
(
layer_fn
(
prefix
=
f
"
{
prefix
}
.
{
idx
}
"
,
first_layer
=
(
idx
==
start_layer
),
last_layer
=
(
idx
==
end_layer
-
1
)))
else
:
return
maybe_offload_to_cpu
(
layer_fn
(
prefix
=
f
"
{
prefix
}
.
{
idx
}
"
))
modules
=
torch
.
nn
.
ModuleList
(
modules
=
torch
.
nn
.
ModuleList
(
[
PPMissingLayer
()
for
_
in
range
(
start_layer
)]
+
[
[
PPMissingLayer
()
for
_
in
range
(
start_layer
)]
+
[
ma
yb
e_o
ffload_to_cpu
(
layer_fn
(
prefix
=
f
"
{
prefix
}
.
{
idx
}
"
)
)
ma
k
e_o
ne_layer
(
idx
,
start_layer
,
end_layer
)
for
idx
in
range
(
start_layer
,
end_layer
)
for
idx
in
range
(
start_layer
,
end_layer
)
]
+
[
PPMissingLayer
()
for
_
in
range
(
end_layer
,
num_hidden_layers
)])
]
+
[
PPMissingLayer
()
for
_
in
range
(
end_layer
,
num_hidden_layers
)])
return
start_layer
,
end_layer
,
modules
return
start_layer
,
end_layer
,
modules
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment