Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3ab7f0ef
"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "1a956e136beae057746af6257ffa8da601730f10"
Commit
3ab7f0ef
authored
Mar 31, 2025
by
zhuwenwen
Browse files
support nn layout
parent
5e766c85
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
56 additions
and
49 deletions
+56
-49
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+5
-4
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+51
-45
No files found.
vllm/distributed/parallel_state.py
View file @
3ab7f0ef
...
@@ -33,7 +33,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
...
@@ -33,7 +33,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union
)
Union
)
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
flux
import
torch
import
torch
import
torch.distributed
import
torch.distributed
from
torch.distributed
import
Backend
,
ProcessGroup
from
torch.distributed
import
Backend
,
ProcessGroup
...
@@ -208,9 +207,11 @@ class GroupCoordinator:
...
@@ -208,9 +207,11 @@ class GroupCoordinator:
self
.
use_hpu_communicator
=
use_hpu_communicator
self
.
use_hpu_communicator
=
use_hpu_communicator
self
.
use_xpu_communicator
=
use_xpu_communicator
self
.
use_xpu_communicator
=
use_xpu_communicator
# Initialize pynvshmem
if
envs
.
VLLM_USE_FLUX
:
if
torch
.
distributed
.
get_world_size
(
self
.
device_group
)
>
1
:
import
flux
flux
.
init_flux_shm
(
self
.
device_group
)
# Initialize pynvshmem
if
torch
.
distributed
.
get_world_size
(
self
.
device_group
)
>
1
:
flux
.
init_flux_shm
(
self
.
device_group
)
# lazy import to avoid documentation build error
# lazy import to avoid documentation build error
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
...
...
vllm/model_executor/layers/linear.py
View file @
3ab7f0ef
...
@@ -4,7 +4,7 @@ import itertools
...
@@ -4,7 +4,7 @@ import itertools
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
typing
import
Optional
,
List
from
typing
import
Optional
,
List
import
flux
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
from
torch.nn.parameter
import
Parameter
,
UninitializedParameter
...
@@ -14,7 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -14,7 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.distributed.parallel_state
import
get_tp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
...
@@ -30,6 +30,12 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -30,6 +30,12 @@ from vllm.model_executor.utils import set_weight_attrs
import
os
import
os
from
vllm.model_executor.utils
import
gemm_bank_conf
from
vllm.model_executor.utils
import
gemm_bank_conf
import
vllm.envs
as
envs
if
envs
.
VLLM_USE_FLUX
:
import
flux
from
vllm.distributed.parallel_state
import
get_tp_group
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -186,14 +192,14 @@ class GemmRS(LinearMethodBase):
...
@@ -186,14 +192,14 @@ class GemmRS(LinearMethodBase):
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
self
.
gemm_rs_op
=
flux
.
GemmRS
(
self
.
gemm_rs_op
=
flux
.
GemmRS
(
get_tp_group
().
device_group
,
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
1
,
# One node
max_m
=
8192
,
# Max M. TODO: Pass in correctly.
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
output_size
,
# N
output_size
,
# N
# TODO: Pass in input dtype correctly.
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
# Note: transpose_weight=False means that B is transposed
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
True
,
transpose_weight
=
True
,
# Note: bfloat16 requires fuse_reduction=False.
# Note: bfloat16 requires fuse_reduction=False.
...
@@ -201,20 +207,20 @@ class GemmRS(LinearMethodBase):
...
@@ -201,20 +207,20 @@ class GemmRS(LinearMethodBase):
)
)
else
:
else
:
self
.
gemm_rs_op
=
flux
.
GemmRS
(
self
.
gemm_rs_op
=
flux
.
GemmRS
(
get_tp_group
().
device_group
,
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
1
,
# One node
max_m
=
8192
,
# Max M. TODO: Pass in correctly.
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
output_size
,
# N
output_size
,
# N
# TODO: Pass in input dtype correctly.
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
# Note: transpose_weight=False means that B is transposed
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
False
,
transpose_weight
=
False
,
# Note: bfloat16 requires fuse_reduction=False.
# Note: bfloat16 requires fuse_reduction=False.
fuse_reduction
=
False
,
fuse_reduction
=
False
,
)
)
def
apply
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -251,16 +257,16 @@ class AGCook(LinearMethodBase):
...
@@ -251,16 +257,16 @@ class AGCook(LinearMethodBase):
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
self
.
ag_gemm_op
=
flux
.
AGKernel
(
self
.
ag_gemm_op
=
flux
.
AGKernel
(
get_tp_group
().
device_group
,
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
1
,
# One node
full_m
=
8192
,
# Max M. TODO: Pass in correctly.
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
weight
.
shape
[
0
],
# N
weight
.
shape
[
0
],
# N
k_dim
=
weight
.
shape
[
1
],
# K
weight
.
shape
[
1
],
# K
# TODO: Pass in input dtype correctly.
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
output_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
# Note: transpose_weight=False means that B is transposed
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
True
,
transpose_weight
=
True
,
# Note: if local_copy=True, I hit the following runtime error:
# Note: if local_copy=True, I hit the following runtime error:
...
@@ -271,25 +277,25 @@ class AGCook(LinearMethodBase):
...
@@ -271,25 +277,25 @@ class AGCook(LinearMethodBase):
)
)
else
:
else
:
self
.
ag_gemm_op
=
flux
.
AGKernel
(
self
.
ag_gemm_op
=
flux
.
AGKernel
(
get_tp_group
().
device_group
,
get_tp_group
().
device_group
,
nnodes
=
1
,
# One node
1
,
# One node
full_m
=
8192
,
# Max M. TODO: Pass in correctly.
8192
,
# Max M. TODO: Pass in correctly.
n_dim
=
weight
.
shape
[
0
],
# N
weight
.
shape
[
0
],
# N
k_dim
=
weight
.
shape
[
1
],
# K
weight
.
shape
[
1
],
# K
# TODO: Pass in input dtype correctly.
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
# Similar comment for max m.
input_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
output_dtype
=
params_dtype
,
#torch.float16,
params_dtype
,
#
torch.float16,
# Note: transpose_weight=False means that B is transposed
# Note: transpose_weight=False means that B is transposed
transpose_weight
=
False
,
transpose_weight
=
False
,
# Note: if local_copy=True, I hit the following runtime error:
# Note: if local_copy=True, I hit the following runtime error:
# /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
# /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
# Check failed: 33554432((input.numel() * input.element_size()))
# Check failed: 33554432((input.numel() * input.element_size()))
# == 139836453421056((this->chunk_size))
# == 139836453421056((this->chunk_size))
local_copy
=
False
,
local_copy
=
False
,
)
)
def
apply
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment