Commit 5e766c85 authored by zhuwenwen
Browse files

[Kernel] Prototype integration of flux kernels

parent feeb058b
......@@ -33,6 +33,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
Union)
from unittest.mock import patch
import flux
import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup
......@@ -206,6 +207,10 @@ class GroupCoordinator:
self.use_tpu_communicator = use_tpu_communicator
self.use_hpu_communicator = use_hpu_communicator
self.use_xpu_communicator = use_xpu_communicator
# Initialize pynvshmem
if torch.distributed.get_world_size(self.device_group) > 1:
flux.init_flux_shm(self.device_group)
# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
......@@ -1282,4 +1287,4 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
aggregated_data += rank_data
return [x == 1 for x in aggregated_data.tolist()]
return [x == 1 for x in aggregated_data.tolist()]
\ No newline at end of file
......@@ -2,8 +2,9 @@
import itertools
from abc import abstractmethod
from typing import Optional
from typing import Optional, List
import flux
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter, UninitializedParameter
......@@ -13,6 +14,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
......@@ -161,6 +163,144 @@ class UnquantizedLinearMethod(LinearMethodBase):
return F.linear(x, layer.weight, bias)
class GemmRS(LinearMethodBase):
    """Fused GEMM + ReduceScatter linear method without quantization.

    Owns an unquantized weight of shape
    ``(sum(output_partition_sizes), input_size_per_partition)`` and runs the
    matmul through a ``flux.GemmRS`` op that is built over the
    tensor-parallel device group when the weights are created.
    """

    # Upper bound on M handed to flux when the op is built.
    # TODO: Pass in correctly.
    _MAX_M = 8192

    def __init__(self, separate_bias_add: bool = False):
        # Function-scope import so that reading the environment does not
        # rely on a module-level ``import os`` being present.
        import os

        # Kept on the instance; not read anywhere else in this class.
        self.separate_bias_add = separate_bias_add
        # LLAMA_NN=1 selects transpose_weight=True for the flux op below.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the ``weight`` parameter on ``layer`` and build the op.

        Args:
            layer: Module that receives the ``weight`` parameter.
            input_size_per_partition: Second dimension of the weight.
            output_partition_sizes: Summed to give the first dimension of
                the weight.
            input_size: Unused here.
            output_size: Passed to flux as ``n_dim``.
            params_dtype: Dtype of the weight; also passed to flux as
                ``input_dtype``.
            **extra_weight_attrs: Extra attributes attached to the weight.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        # The LLAMA_NN and default configurations differ only in
        # ``transpose_weight``, so the op is constructed once.
        self.gemm_rs_op = flux.GemmRS(
            get_tp_group().device_group,
            nnodes=1,  # One node
            max_m=self._MAX_M,  # Max M
            n_dim=output_size,  # N
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on
            # dtype at run time, but I don't know what the downside would
            # be. Similar comment for max m.
            input_dtype=params_dtype,
            # Note: transpose_weight=False means that B is transposed
            transpose_weight=self.use_llama_nn,
            # Note: bfloat16 requires fuse_reduction=False.
            fuse_reduction=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the fused op on ``x`` with ``layer.weight``.

        Args:
            layer: Module holding the ``weight`` parameter.
            x: Input activations.
            bias: Must be ``None``; this method never adds a bias.

        Returns:
            The op's output after ``squeeze(0)``.

        Raises:
            AssertionError: If ``bias`` is not ``None``.
        """
        # Explicit raise instead of a bare ``assert`` so the check is not
        # stripped under ``python -O`` (which would silently drop the bias).
        if bias is not None:
            raise AssertionError("GemmRS does not support a bias term")
        output = self.gemm_rs_op.forward(x, layer.weight)
        return output.squeeze(0)
class AGCook(LinearMethodBase):
    """Fused AllGather + GEMM linear method without quantization.

    Owns an unquantized weight of shape
    ``(sum(output_partition_sizes), input_size_per_partition)`` and runs the
    matmul through a ``flux.AGKernel`` op that is built over the
    tensor-parallel device group when the weights are created.
    """

    # Value handed to flux as ``full_m`` when the op is built.
    # TODO: Pass in correctly.
    _FULL_M = 8192

    def __init__(self, separate_bias_add: bool = False):
        # Function-scope import so that reading the environment does not
        # rely on a module-level ``import os`` being present.
        import os

        # Kept on the instance; not read anywhere else in this class.
        self.separate_bias_add = separate_bias_add
        # LLAMA_NN=1 selects transpose_weight=True for the flux op below.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the ``weight`` parameter on ``layer`` and build the op.

        Args:
            layer: Module that receives the ``weight`` parameter.
            input_size_per_partition: Second dimension of the weight;
                passed to flux as ``k_dim`` via ``weight.shape[1]``.
            output_partition_sizes: Summed to give the first dimension of
                the weight; passed to flux as ``n_dim`` via
                ``weight.shape[0]``.
            input_size: Unused here.
            output_size: Unused here.
            params_dtype: Dtype of the weight; also passed to flux as both
                ``input_dtype`` and ``output_dtype``.
            **extra_weight_attrs: Extra attributes attached to the weight.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

        # The LLAMA_NN and default configurations differ only in
        # ``transpose_weight``, so the op is constructed once.
        self.ag_gemm_op = flux.AGKernel(
            get_tp_group().device_group,
            nnodes=1,  # One node
            full_m=self._FULL_M,  # Max M
            n_dim=weight.shape[0],  # N
            k_dim=weight.shape[1],  # K
            # TODO: Pass in input dtype correctly.
            # TODO: It would be nicer to modify flux to dispatch based on
            # dtype at run time, but I don't know what the downside would
            # be. Similar comment for max m.
            input_dtype=params_dtype,
            output_dtype=params_dtype,
            # Note: transpose_weight=False means that B is transposed
            transpose_weight=self.use_llama_nn,
            # Note: if local_copy=True, I hit the following runtime error:
            # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648
            # Check failed: 33554432((input.numel() * input.element_size()))
            # == 139836453421056((this->chunk_size))
            local_copy=False,
        )

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the fused op on ``x`` with ``layer.weight``.

        Args:
            layer: Module holding the ``weight`` parameter.
            x: Input activations.
            bias: Must be ``None``; this method never adds a bias.

        Returns:
            The op's output, unchanged.

        Raises:
            AssertionError: If ``bias`` is not ``None``.
        """
        # Explicit raise instead of a bare ``assert`` so the check is not
        # stripped under ``python -O`` (which would silently drop the bias).
        if bias is not None:
            raise AssertionError("AGCook does not support a bias term")
        return self.ag_gemm_op.forward(x, layer.weight)
class LinearBase(torch.nn.Module):
"""Base linear layer.
......@@ -181,6 +321,8 @@ class LinearBase(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
fuse_gemm_rs: bool = False,
fuse_ag_gemm: bool = False,
):
super().__init__()
......@@ -191,9 +333,14 @@ class LinearBase(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
if fuse_gemm_rs:
assert (quant_config is None)
self.quant_method: Optional[QuantizeMethodBase] = GemmRS()
elif fuse_ag_gemm:
assert (quant_config is None)
self.quant_method = AGCook()
elif quant_config is None:
self.quant_method = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
......@@ -308,9 +455,10 @@ class ColumnParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
prefix: str = ""):
prefix: str = "",
fuse_ag_gemm: bool = False):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
quant_config, prefix, fuse_ag_gemm=fuse_ag_gemm)
self.gather_output = gather_output
......@@ -447,7 +595,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
fuse_ag_gemm: bool = False):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
......@@ -458,7 +607,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
fuse_ag_gemm=fuse_ag_gemm)
def weight_loader(self,
param: Parameter,
......@@ -723,7 +873,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
fuse_ag_gemm: bool = False):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
......@@ -756,7 +907,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix)
prefix=prefix,
fuse_ag_gemm=fuse_ag_gemm)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
......@@ -1060,12 +1212,15 @@ class RowParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
prefix: str = "",
fuse_gemm_rs: bool = False):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
quant_config, prefix, fuse_gemm_rs=fuse_gemm_rs)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if fuse_gemm_rs:
self.reduce_results = False
# Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank()
......
......@@ -33,7 +33,7 @@ import vllm.envs as envs
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank, tensor_model_parallel_all_gather
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
......@@ -71,6 +71,7 @@ class LlamaMLP(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
last_layer: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
......@@ -79,6 +80,7 @@ class LlamaMLP(nn.Module):
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
fuse_ag_gemm=True,
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
......@@ -86,6 +88,7 @@ class LlamaMLP(nn.Module):
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
fuse_gemm_rs=(not last_layer),
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
......@@ -106,6 +109,7 @@ class LlamaAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
first_layer: bool,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
......@@ -148,6 +152,7 @@ class LlamaAttention(nn.Module):
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
fuse_ag_gemm=(not first_layer),
)
self.o_proj = RowParallelLinear(
......@@ -156,6 +161,7 @@ class LlamaAttention(nn.Module):
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
fuse_gemm_rs=True,
)
is_neox_style = True
......@@ -223,6 +229,11 @@ class LlamaDecoderLayer(nn.Module):
def __init__(
self,
config: LlamaConfig,
# Hack: pass in whether this is the first/last layer
# so we know if we can rewrite AllReduce -> ReduceScatter + AllGather,
# and then propagate the AllGather to the next layer.
first_layer: bool,
last_layer: bool,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
......@@ -252,6 +263,7 @@ class LlamaDecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
first_layer=first_layer,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
......@@ -268,11 +280,15 @@ class LlamaDecoderLayer(nn.Module):
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
last_layer=last_layer,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.first_layer = first_layer
self.last_layer = last_layer
def forward(
self,
......@@ -287,17 +303,35 @@ class LlamaDecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
assert (hidden_states.shape == residual.shape)
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
# Partition residual
if self.first_layer:
n_slices = get_tensor_model_parallel_world_size()
residual_slices = torch.chunk(residual, n_slices, dim=0)
my_residual = residual_slices[get_tensor_model_parallel_rank()]
else:
my_residual = residual
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
assert (hidden_states.shape == my_residual.shape)
hidden_states, my_residual = self.post_attention_layernorm(
hidden_states, my_residual)
hidden_states = self.mlp(hidden_states)
if self.last_layer:
residual = tensor_model_parallel_all_gather(my_residual, 0)
else:
residual = my_residual
assert (hidden_states.shape == residual.shape)
return hidden_states, residual
......@@ -335,7 +369,9 @@ class LlamaModel(nn.Module):
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: layer_type(config=config,
lambda prefix, first_layer, last_layer: layer_type(config=config,
first_layer=first_layer,
last_layer=last_layer,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
......@@ -763,4 +799,4 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])
return name, loaded_weight
return name, loaded_weight
\ No newline at end of file
......@@ -548,14 +548,31 @@ def make_layers(
"""Make a list of layers with the given layer function, taking
pipeline parallelism into account.
"""
import inspect
from vllm.distributed.parallel_state import get_pp_group
from vllm.distributed.utils import get_pp_indices
start_layer, end_layer = get_pp_indices(num_hidden_layers,
get_pp_group().rank_in_group,
get_pp_group().world_size)
# Determine if layer_fn accepts first/last args by inspecting its signature
sig = inspect.signature(layer_fn)
has_firstlast_args = ('first_layer'
in sig.parameters) and ('last_layer'
in sig.parameters)
def make_one_layer(idx, start_layer, end_layer):
    """Build the layer at index ``idx`` and pass it to maybe_offload_to_cpu.

    ``first_layer`` / ``last_layer`` are forwarded to ``layer_fn`` only when
    ``has_firstlast_args`` is true. A layer counts as first when
    ``idx == start_layer`` and as last when ``idx == end_layer - 1``.
    """
    layer_kwargs = {"prefix": f"{prefix}.{idx}"}
    if has_firstlast_args:
        layer_kwargs["first_layer"] = idx == start_layer
        layer_kwargs["last_layer"] = idx == end_layer - 1
    return maybe_offload_to_cpu(layer_fn(**layer_kwargs))
modules = torch.nn.ModuleList(
[PPMissingLayer() for _ in range(start_layer)] + [
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
make_one_layer(idx, start_layer, end_layer)
for idx in range(start_layer, end_layer)
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
return start_layer, end_layer, modules
......@@ -640,4 +657,4 @@ def extract_layer_index(layer_name: str) -> int:
continue
assert len(int_vals) == 1, (f"layer name {layer_name} should"
" only contain one integer")
return int_vals[0]
return int_vals[0]
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment