Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ab7f0ef
Commit
3ab7f0ef
authored
Mar 31, 2025
by
zhuwenwen
Browse files
support nn layout
parent
5e766c85
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
49 deletions
+56
-49
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+5
-4
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+51
-45
No files found.
vllm/distributed/parallel_state.py
View file @
3ab7f0ef
...
...
@@ -33,7 +33,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union
)
from
unittest.mock
import
patch
import
flux
import
torch
import
torch.distributed
from
torch.distributed
import
Backend
,
ProcessGroup
...
...
@@ -208,6 +207,8 @@ class GroupCoordinator:
self
.
use_hpu_communicator
=
use_hpu_communicator
self
.
use_xpu_communicator
=
use_xpu_communicator
if
envs
.
VLLM_USE_FLUX
:
import
flux
# Initialize pynvshmem
if
torch
.
distributed
.
get_world_size
(
self
.
device_group
)
>
1
:
flux
.
init_flux_shm
(
self
.
device_group
)
...
...
vllm/model_executor/layers/linear.py
View file @
3ab7f0ef
...
...
@@ -4,7 +4,7 @@ import itertools
from
abc
import
abstractmethod
from
typing
import
Optional
,
List
import
flux
import
torch
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
...
...
@@ -14,7 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
...
...
@@ -30,6 +30,12 @@ from vllm.model_executor.utils import set_weight_attrs
import
os
from
vllm.model_executor.utils
import
gemm_bank_conf
import
vllm.envs
as
envs
if
envs
.
VLLM_USE_FLUX
:
import
flux
from
vllm.distributed.parallel_state
import
get_tp_group
logger
=
init_logger
(
__name__
)
...
...
@@ -186,14 +192,14 @@ class GemmRS(LinearMethodBase):
if
self
.
use_llama_nn
:
self
.
gemm_rs_op
=
flux
.
GemmRS
(
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
max_m
=
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
output_size
,
# N
1
,
# One node
8192
,
# Max M. TODO: Pass in correctly.
output_size
,
# N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
True
,
# Note: bfloat16 requires fuse_reduction=False.
...
...
@@ -202,14 +208,14 @@ class GemmRS(LinearMethodBase):
else
:
self
.
gemm_rs_op
=
flux
.
GemmRS
(
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
max_m
=
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
output_size
,
# N
1
,
# One node
8192
,
# Max M. TODO: Pass in correctly.
output_size
,
# N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
False
,
# Note: bfloat16 requires fuse_reduction=False.
...
...
@@ -251,16 +257,16 @@ class AGCook(LinearMethodBase):
if
self
.
use_llama_nn
:
self
.
ag_gemm_op
=
flux
.
AGKernel
(
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
full_m
=
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
weight
.
shape
[
0
],
# N
k_dim
=
weight
.
shape
[
1
],
# K
1
,
# One node
8192
,
# Max M. TODO: Pass in correctly.
weight
.
shape
[
0
],
# N
weight
.
shape
[
1
],
# K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
output_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
params_dtype
,
#
torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
True
,
# Note: if local_copy=True, I hit the following runtime error:
...
...
@@ -272,16 +278,16 @@ class AGCook(LinearMethodBase):
else
:
self
.
ag_gemm_op
=
flux
.
AGKernel
(
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
full_m
=
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
weight
.
shape
[
0
],
# N
k_dim
=
weight
.
shape
[
1
],
# K
1
,
# One node
8192
,
# Max M. TODO: Pass in correctly.
weight
.
shape
[
0
],
# N
weight
.
shape
[
1
],
# K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
output_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
params_dtype
,
#
torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
False
,
# Note: if local_copy=True, I hit the following runtime error:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment