"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "1a956e136beae057746af6257ffa8da601730f10"
Commit 3ab7f0ef authored by zhuwenwen's avatar zhuwenwen
Browse files

suppoet nn layout

parent 5e766c85
...@@ -33,7 +33,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, ...@@ -33,7 +33,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union) Union)
from unittest.mock import patch from unittest.mock import patch
import flux
import torch import torch
import torch.distributed import torch.distributed
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup
...@@ -208,9 +207,11 @@ class GroupCoordinator: ...@@ -208,9 +207,11 @@ class GroupCoordinator:
self.use_hpu_communicator = use_hpu_communicator self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator self.use_xpu_communicator = use_xpu_communicator
# Initialize pynvshmem if envs.VLLM_USE_FLUX:
if torch.distributed.get_world_size(self.device_group) > 1: import flux
flux.init_flux_shm(self.device_group) # Initialize pynvshmem
if torch.distributed.get_world_size(self.device_group) > 1:
flux.init_flux_shm(self.device_group)
# lazy import to avoid documentation build error # lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import ( from vllm.distributed.device_communicators.custom_all_reduce import (
......
...@@ -4,7 +4,7 @@ import itertools ...@@ -4,7 +4,7 @@ import itertools
from abc import abstractmethod from abc import abstractmethod
from typing import Optional, List from typing import Optional, List
import flux
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
...@@ -14,7 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -14,7 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
...@@ -30,6 +30,12 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -30,6 +30,12 @@ from vllm.model_executor.utils import set_weight_attrs
import os import os
from vllm.model_executor.utils import gemm_bank_conf from vllm.model_executor.utils import gemm_bank_conf
import vllm.envs as envs
if envs.VLLM_USE_FLUX:
import flux
from vllm.distributed.parallel_state import get_tp_group
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -186,14 +192,14 @@ class GemmRS(LinearMethodBase): ...@@ -186,14 +192,14 @@ class GemmRS(LinearMethodBase):
if self.use_llama_nn: if self.use_llama_nn:
self.gemm_rs_op = flux.GemmRS( self.gemm_rs_op = flux.GemmRS(
get_tp_group().device_group, get_tp_group().device_group,
nnodes=1, # One node 1, # One node
max_m=8192, # Max M. TODO: Pass in correctly. 8192, # Max M. TODO: Pass in correctly.
n_dim=output_size, # N output_size, # N
# TODO: Pass in input dtype correctly. # TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype # TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be. # at run time, but I don't know what the downside would be.
# Similar comment for max m. # Similar comment for max m.
input_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed # Note: transpose_weight=False means that B is transposed
transpose_weight=True, transpose_weight=True,
# Note: bfloat16 requires fuse_reduction=False. # Note: bfloat16 requires fuse_reduction=False.
...@@ -201,20 +207,20 @@ class GemmRS(LinearMethodBase): ...@@ -201,20 +207,20 @@ class GemmRS(LinearMethodBase):
) )
else: else:
self.gemm_rs_op = flux.GemmRS( self.gemm_rs_op = flux.GemmRS(
get_tp_group().device_group, get_tp_group().device_group,
nnodes=1, # One node 1, # One node
max_m=8192, # Max M. TODO: Pass in correctly. 8192, # Max M. TODO: Pass in correctly.
n_dim=output_size, # N output_size, # N
# TODO: Pass in input dtype correctly. # TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype # TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be. # at run time, but I don't know what the downside would be.
# Similar comment for max m. # Similar comment for max m.
input_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed # Note: transpose_weight=False means that B is transposed
transpose_weight=False, transpose_weight=False,
# Note: bfloat16 requires fuse_reduction=False. # Note: bfloat16 requires fuse_reduction=False.
fuse_reduction=False, fuse_reduction=False,
) )
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -251,16 +257,16 @@ class AGCook(LinearMethodBase): ...@@ -251,16 +257,16 @@ class AGCook(LinearMethodBase):
if self.use_llama_nn: if self.use_llama_nn:
self.ag_gemm_op = flux.AGKernel( self.ag_gemm_op = flux.AGKernel(
get_tp_group().device_group, get_tp_group().device_group,
nnodes=1, # One node 1, # One node
full_m=8192, # Max M. TODO: Pass in correctly. 8192, # Max M. TODO: Pass in correctly.
n_dim=weight.shape[0], # N weight.shape[0], # N
k_dim=weight.shape[1], # K weight.shape[1], # K
# TODO: Pass in input dtype correctly. # TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype # TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be. # at run time, but I don't know what the downside would be.
# Similar comment for max m. # Similar comment for max m.
input_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
output_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed # Note: transpose_weight=False means that B is transposed
transpose_weight=True, transpose_weight=True,
# Note: if local_copy=True, I hit the following runtime error: # Note: if local_copy=True, I hit the following runtime error:
...@@ -271,25 +277,25 @@ class AGCook(LinearMethodBase): ...@@ -271,25 +277,25 @@ class AGCook(LinearMethodBase):
) )
else: else:
self.ag_gemm_op = flux.AGKernel( self.ag_gemm_op = flux.AGKernel(
get_tp_group().device_group, get_tp_group().device_group,
nnodes=1, # One node 1, # One node
full_m=8192, # Max M. TODO: Pass in correctly. 8192, # Max M. TODO: Pass in correctly.
n_dim=weight.shape[0], # N weight.shape[0], # N
k_dim=weight.shape[1], # K weight.shape[1], # K
# TODO: Pass in input dtype correctly. # TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype # TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be. # at run time, but I don't know what the downside would be.
# Similar comment for max m. # Similar comment for max m.
input_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
output_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed # Note: transpose_weight=False means that B is transposed
transpose_weight=False, transpose_weight=False,
# Note: if local_copy=True, I hit the following runtime error: # Note: if local_copy=True, I hit the following runtime error:
# /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648 # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
# Check failed: 33554432((input.numel() * input.element_size())) # Check failed: 33554432((input.numel() * input.element_size()))
# == 139836453421056((this->chunk_size)) # == 139836453421056((this->chunk_size))
local_copy=False, local_copy=False,
) )
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment