Commit 3ab7f0ef authored by zhuwenwen
Browse files

support nn layout

parent 5e766c85
......@@ -33,7 +33,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
from unittest.mock import patch
import flux
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
......@@ -208,9 +207,11 @@ class GroupCoordinator:
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
# Initialize pynvshmem
if torch.distributed.get_world_size(self.device_group) > 1:
flux.init_flux_shm(self.device_group)
if envs.VLLM_USE_FLUX:
import flux
# Initialize pynvshmem
if torch.distributed.get_world_size(self.device_group) > 1:
flux.init_flux_shm(self.device_group)
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
......
......@@ -4,7 +4,7 @@ import itertools
from abc import abstractmethod
from typing import Optional, List
import flux
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
......@@ -14,7 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
......@@ -30,6 +30,12 @@ from vllm.model_executor.utils import set_weight_attrs
import os
from vllm.model_executor.utils import gemm_bank_conf
import vllm.envs as envs
if envs.VLLM_USE_FLUX:
import flux
from vllm.distributed.parallel_state import get_tp_group
logger = init_logger(__name__)
......@@ -186,14 +192,14 @@ class GemmRS(LinearMethodBase):
if self.use_llama_nn:
self.gemm_rs_op = flux.GemmRS(
get_tp_group().device_group,
nnodes=1, # One node
max_m=8192, # Max M. TODO: Pass in correctly.
n_dim=output_size, # N
1, # One node
8192, # Max M. TODO: Pass in correctly.
output_size, # N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype=params_dtype, #torch.float16,
params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=True,
# Note: bfloat16 requires fuse_reduction=False.
......@@ -201,20 +207,20 @@ class GemmRS(LinearMethodBase):
)
else:
self.gemm_rs_op = flux.GemmRS(
get_tp_group().device_group,
nnodes=1, # One node
max_m=8192, # Max M. TODO: Pass in correctly.
n_dim=output_size, # N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype=params_dtype, #torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=False,
# Note: bfloat16 requires fuse_reduction=False.
fuse_reduction=False,
)
get_tp_group().device_group,
1, # One node
8192, # Max M. TODO: Pass in correctly.
output_size, # N
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=False,
# Note: bfloat16 requires fuse_reduction=False.
fuse_reduction=False,
)
def apply(self,
layer: torch.nn.Module,
......@@ -251,16 +257,16 @@ class AGCook(LinearMethodBase):
if self.use_llama_nn:
self.ag_gemm_op = flux.AGKernel(
get_tp_group().device_group,
nnodes=1, # One node
full_m=8192, # Max M. TODO: Pass in correctly.
n_dim=weight.shape[0], # N
k_dim=weight.shape[1], # K
1, # One node
8192, # Max M. TODO: Pass in correctly.
weight.shape[0], # N
weight.shape[1], # K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype=params_dtype, #torch.float16,
output_dtype=params_dtype, #torch.float16,
params_dtype, # torch.float16,
params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=True,
# Note: if local_copy=True, I hit the following runtime error:
......@@ -271,25 +277,25 @@ class AGCook(LinearMethodBase):
)
else:
self.ag_gemm_op = flux.AGKernel(
get_tp_group().device_group,
nnodes=1, # One node
full_m=8192, # Max M. TODO: Pass in correctly.
n_dim=weight.shape[0], # N
k_dim=weight.shape[1], # K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
input_dtype=params_dtype, #torch.float16,
output_dtype=params_dtype, #torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=False,
# Note: if local_copy=True, I hit the following runtime error:
# /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
# Check failed: 33554432((input.numel() * input.element_size()))
# == 139836453421056((this->chunk_size))
local_copy=False,
)
get_tp_group().device_group,
1, # One node
8192, # Max M. TODO: Pass in correctly.
weight.shape[0], # N
weight.shape[1], # K
# TODO: Pass in input dtype correctly.
# TODO: It would be nicer to modify flux to dispatch based on dtype
# at run time, but I don't know what the downside would be.
# Similar comment for max m.
params_dtype, # torch.float16,
params_dtype, # torch.float16,
# Note: transpose_weight=False means that B is transposed
transpose_weight=False,
# Note: if local_copy=True, I hit the following runtime error:
# /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
# Check failed: 33554432((input.numel() * input.element_size()))
# == 139836453421056((this->chunk_size))
local_copy=False,
)
def apply(self,
layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment