"vscode:/vscode.git/clone" did not exist on "2c7fa47161ba513817a80e165c86a66760c06ebb"
Commit 3ab7f0ef authored by zhuwenwen
Browse files

support nn layout

parent 5e766c85
...@@ -33,7 +33,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, ...@@ -33,7 +33,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union) Union)
from unittest.mock import patch from unittest.mock import patch
import flux
import torch import torch
import torch.distributed import torch.distributed
from torch.distributed import Backend, ProcessGroup from torch.distributed import Backend, ProcessGroup
...@@ -208,6 +207,8 @@ class GroupCoordinator: ...@@ -208,6 +207,8 @@ class GroupCoordinator:
self.use_hpu_communicator = use_hpu_communicator self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator self.use_xpu_communicator = use_xpu_communicator
if envs.VLLM_USE_FLUX:
import flux
# Initialize pynvshmem # Initialize pynvshmem
if torch.distributed.get_world_size(self.device_group) > 1: if torch.distributed.get_world_size(self.device_group) > 1:
flux.init_flux_shm(self.device_group) flux.init_flux_shm(self.device_group)
......
...@@ -4,7 +4,7 @@ import itertools ...@@ -4,7 +4,7 @@ import itertools
from abc import abstractmethod from abc import abstractmethod
from typing import Optional, List from typing import Optional, List
import flux
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter from torch.nn.parameter import Parameter, UninitializedParameter
...@@ -14,7 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank, ...@@ -14,7 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
split_tensor_along_last_dim, split_tensor_along_last_dim,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
...@@ -30,6 +30,12 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -30,6 +30,12 @@ from vllm.model_executor.utils import set_weight_attrs
import os import os
from vllm.model_executor.utils import gemm_bank_conf from vllm.model_executor.utils import gemm_bank_conf
import vllm.envs as envs
if envs.VLLM_USE_FLUX:
import flux
from vllm.distributed.parallel_state import get_tp_group
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -186,14 +192,14 @@ class GemmRS(LinearMethodBase): ...@@ -186,14 +192,14 @@ class GemmRS(LinearMethodBase):
if self.use_llama_nn: if self.use_llama_nn:
self.gemm_rs_op = flux.GemmRS( self.gemm_rs_op = flux.GemmRS(
get_tp_group().device_group, get_tp_group().device_group,
nnodes=1, # One node 1, # One node
max_m=8192, # Max M. TODO: Pass in correctly. 8192, # Max M. TODO: Pass in correctly.
n_dim=output_size, # N output_size, # N
# TODO: Pass in input dtype correctly. # TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype # TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be. # at run time, but I don't know what the downside would be.
# Similar comment for max m. # Similar comment for max m.
input_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed # Note: transpose_weight=False means that B is transposed
transpose_weight=True, transpose_weight=True,
# Note: bfloat16 requires fuse_reduction=False. # Note: bfloat16 requires fuse_reduction=False.
...@@ -202,14 +208,14 @@ class GemmRS(LinearMethodBase): ...@@ -202,14 +208,14 @@ class GemmRS(LinearMethodBase):
else: else:
self.gemm_rs_op = flux.GemmRS( self.gemm_rs_op = flux.GemmRS(
get_tp_group().device_group, get_tp_group().device_group,
nnodes=1, # One node 1, # One node
max_m=8192, # Max M. TODO: Pass in correctly. 8192, # Max M. TODO: Pass in correctly.
n_dim=output_size, # N output_size, # N
# TODO: Pass in input dtype correctly. # TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype # TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be. # at run time, but I don't know what the downside would be.
# Similar comment for max m. # Similar comment for max m.
input_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed # Note: transpose_weight=False means that B is transposed
transpose_weight=False, transpose_weight=False,
# Note: bfloat16 requires fuse_reduction=False. # Note: bfloat16 requires fuse_reduction=False.
...@@ -251,16 +257,16 @@ class AGCook(LinearMethodBase): ...@@ -251,16 +257,16 @@ class AGCook(LinearMethodBase):
if self.use_llama_nn: if self.use_llama_nn:
self.ag_gemm_op = flux.AGKernel( self.ag_gemm_op = flux.AGKernel(
get_tp_group().device_group, get_tp_group().device_group,
nnodes=1, # One node 1, # One node
full_m=8192, # Max M. TODO: Pass in correctly. 8192, # Max M. TODO: Pass in correctly.
n_dim=weight.shape[0], # N weight.shape[0], # N
k_dim=weight.shape[1], # K weight.shape[1], # K
# TODO: Pass in input dtype correctly. # TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype # TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be. # at run time, but I don't know what the downside would be.
# Similar comment for max m. # Similar comment for max m.
input_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
output_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed # Note: transpose_weight=False means that B is transposed
transpose_weight=True, transpose_weight=True,
# Note: if local_copy=True, I hit the following runtime error: # Note: if local_copy=True, I hit the following runtime error:
...@@ -272,16 +278,16 @@ class AGCook(LinearMethodBase): ...@@ -272,16 +278,16 @@ class AGCook(LinearMethodBase):
else: else:
self.ag_gemm_op = flux.AGKernel( self.ag_gemm_op = flux.AGKernel(
get_tp_group().device_group, get_tp_group().device_group,
nnodes=1, # One node 1, # One node
full_m=8192, # Max M. TODO: Pass in correctly. 8192, # Max M. TODO: Pass in correctly.
n_dim=weight.shape[0], # N weight.shape[0], # N
k_dim=weight.shape[1], # K weight.shape[1], # K
# TODO: Pass in input dtype correctly. # TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype # TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be. # at run time, but I don't know what the downside would be.
# Similar comment for max m. # Similar comment for max m.
input_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
output_dtype=params_dtype, #torch.float16, params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed # Note: transpose_weight=False means that B is transposed
transpose_weight=False, transpose_weight=False,
# Note: if local_copy=True, I hit the following runtime error: # Note: if local_copy=True, I hit the following runtime error:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment