Unverified commit 1b0bd0fe authored by Zhuohan Li, committed by GitHub

Add Falcon support (new) (#592)

parent 20044cab
@@ -44,6 +44,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:
 - Baichuan-7B (`baichuan-inc/Baichuan-7B`)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
......
...@@ -10,7 +10,8 @@ __global__ void rotary_embedding_neox_kernel( ...@@ -10,7 +10,8 @@ __global__ void rotary_embedding_neox_kernel(
scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size] scalar_t* __restrict__ key, // [num_tokens, num_kv_heads, head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
const int rot_dim, const int rot_dim,
const int stride, const int query_stride,
const int key_stride,
const int num_heads, const int num_heads,
const int num_kv_heads, const int num_kv_heads,
const int head_size) { const int head_size) {
@@ -23,14 +24,14 @@ __global__ void rotary_embedding_neox_kernel(
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * stride + head_idx * head_size;
+    const int token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
     const int x_index = rot_offset;
     const int y_index = embed_dim + rot_offset;
-    const int out_x = token_idx * stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+    const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
     const scalar_t cos = __ldg(cache_ptr + x_index);
     const scalar_t sin = __ldg(cache_ptr + y_index);
@@ -39,14 +40,28 @@ __global__ void rotary_embedding_neox_kernel(
     const scalar_t q_y = query[token_head + y_index];
     query[out_x] = q_x * cos - q_y * sin;
     query[out_y] = q_y * cos + q_x * sin;
+  }
+
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * key_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    const int x_index = rot_offset;
+    const int y_index = embed_dim + rot_offset;
+    const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
+    const scalar_t cos = __ldg(cache_ptr + x_index);
+    const scalar_t sin = __ldg(cache_ptr + y_index);
-    if (head_idx < num_kv_heads) {
     const scalar_t k_x = key[token_head + x_index];
     const scalar_t k_y = key[token_head + y_index];
     key[out_x] = k_x * cos - k_y * sin;
     key[out_y] = k_y * cos + k_x * sin;
-    }
   }
 }

 } // namespace vllm
@@ -62,8 +77,8 @@ void rotary_embedding_neox(
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
   int num_kv_heads = key.size(1) / head_size;
-  int stride = query.stride(0);
-  TORCH_CHECK(stride == key.stride(0));
+  int query_stride = query.stride(0);
+  int key_stride = key.stride(0);
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
@@ -80,7 +95,8 @@ void rotary_embedding_neox(
         key.data_ptr<scalar_t>(),
         cos_sin_cache.data_ptr<scalar_t>(),
         rot_dim,
-        stride,
+        query_stride,
+        key_stride,
         num_heads,
         num_kv_heads,
         head_size);
......
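The kernel change above splits the old shared `stride` into `query_stride` and `key_stride` and rotates the key heads in their own loop, because with Falcon's multi-query/grouped-query attention the key tensor has fewer heads than the query tensor and therefore a different row stride. Below is a rough PyTorch reference for the NeoX-style rotation the kernel applies to query and key tensors with different head counts; the function name, cache layout, and example shapes are illustrative assumptions, not vLLM's Python API.

```python
# Minimal PyTorch reference for the NeoX-style rotary update (illustrative only).
import torch

def rotary_embedding_neox_ref(positions, query, key, cos_sin_cache, rot_dim):
    # query: [num_tokens, num_heads, head_size]
    # key:   [num_tokens, num_kv_heads, head_size] (may have fewer heads than query)
    # cos_sin_cache: [max_position, rot_dim] laid out as [cos | sin] halves (assumption)
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [num_tokens, rot_dim // 2]
    cos = cos[:, None, :]  # broadcast over the head dimension
    sin = sin[:, None, :]

    def rotate(x):
        rot, passthrough = x[..., :rot_dim], x[..., rot_dim:]
        x1, x2 = rot.chunk(2, dim=-1)  # NeoX style: split into halves, not interleaved
        rot = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
        return torch.cat((rot, passthrough), dim=-1)

    return rotate(query), rotate(key)

# Example shapes loosely modeled on Falcon-40B: 128 query heads, 8 KV heads, head_size 64.
q = torch.randn(4, 128, 64)
k = torch.randn(4, 8, 64)
cache = torch.randn(2048, 64)  # rot_dim == head_size here
positions = torch.tensor([0, 1, 2, 3])
q_rot, k_rot = rotary_embedding_neox_ref(positions, q, k, cache, rot_dim=64)
```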
@@ -20,6 +20,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`BloomForCausalLM`
     - BLOOM, BLOOMZ, BLOOMChat
     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
+  * - :code:`FalconForCausalLM`
+    - Falcon
+    - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.
......
@@ -10,7 +10,8 @@ def main(args: argparse.Namespace):
     # Test the following prompts.
     test_prompts = [
-        ("A robot may not injure a human being", SamplingParams()),
+        ("A robot may not injure a human being",
+         SamplingParams(temperature=0.0)),
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
......
@@ -94,8 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
-        # For GPTBigCode:
-        if getattr(self.hf_config, "multi_query", False):
+        # For GPTBigCode & Falcon:
+        # Note: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        if (getattr(self.hf_config, "multi_query", False) and
+                (self.hf_config.model_type == "falcon" and
+                 not getattr(self.hf_config, "new_decoder_architecture", False))):
             # Multi-query attention, only one KV head.
             return 1
         # For Falcon:
......
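For orientation, the branch above means: a checkpoint that advertises plain multi-query attention and is not a new-decoder-architecture Falcon resolves to a single KV head, while new-architecture Falcon (e.g. falcon-40b) falls through to the `n_head_kv`-based branch further down in `get_num_heads`. A hedged sketch of that decision, using stand-in `SimpleNamespace` configs rather than real HF configs and paraphrasing the later branches:

```python
# Hedged sketch of how the branch above resolves the KV-head count for the two
# Falcon variants (stand-in configs; the real logic lives in ModelConfig).
from types import SimpleNamespace

def num_kv_heads(hf_config, tensor_parallel_size: int = 1) -> int:
    # Old-architecture Falcon with multi_query=True uses a single shared KV head;
    # new_decoder_architecture models fall through to the n_head_kv branch.
    if (getattr(hf_config, "multi_query", False) and
            (hf_config.model_type == "falcon" and
             not getattr(hf_config, "new_decoder_architecture", False))):
        return 1
    if getattr(hf_config, "n_head_kv", None) is not None:
        return hf_config.n_head_kv // tensor_parallel_size  # grouped-query attention
    return hf_config.num_attention_heads // tensor_parallel_size

falcon_7b = SimpleNamespace(model_type="falcon", multi_query=True,
                            new_decoder_architecture=False, n_head_kv=1,
                            num_attention_heads=71)
falcon_40b = SimpleNamespace(model_type="falcon", multi_query=True,
                             new_decoder_architecture=True, n_head_kv=8,
                             num_attention_heads=128)
assert num_kv_heads(falcon_7b) == 1   # MQA: one shared KV head
assert num_kv_heads(falcon_40b) == 8  # GQA: KV heads taken from n_head_kv
```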
@@ -314,14 +314,13 @@ class PagedAttentionWithRoPE(PagedAttention):
 class PagedAttentionWithALiBi(PagedAttention):
     """PagedAttention with ALiBi attention bias."""

-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        slopes: List[float],
-    ) -> None:
-        super().__init__(num_heads, head_size, scale)
+    def __init__(self,
+                 num_heads: int,
+                 head_size: int,
+                 scale: float,
+                 slopes: List[float],
+                 num_kv_heads: Optional[int] = None) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads)
         assert len(slopes) == num_heads
         slopes = torch.tensor(slopes, dtype=torch.float32)
@@ -334,6 +333,11 @@ class PagedAttentionWithALiBi(PagedAttention):
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
             bias = torch.arange(prompt_len)
+            # Note(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(prompt_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
             bias = bias[None, :] - bias[:, None]
             bias = bias.to(self.alibi_slopes.device)
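The claim in the new comment (that the HF-style bias and the paper-style bias give the same results) holds because the two matrices differ only by a constant per query row, and softmax is invariant to adding a constant to a row of logits. A quick standalone check with illustrative values (not vLLM code):

```python
# Check that the two ALiBi bias constructions give the same attention weights.
import torch

prompt_len, slope = 5, 0.5
bias = torch.arange(prompt_len)

hf_bias = bias[None, :].repeat(prompt_len, 1).float()   # HF-style: same row repeated
paper_bias = (bias[None, :] - bias[:, None]).float()    # paper-style: relative distances

scores = torch.randn(prompt_len, prompt_len)             # dummy attention logits
mask = torch.triu(torch.full_like(scores, float("-inf")), diagonal=1)  # causal mask

p_hf = torch.softmax(scores + slope * hf_bias + mask, dim=-1)
p_paper = torch.softmax(scores + slope * paper_bias + mask, dim=-1)
assert torch.allclose(p_hf, p_paper, atol=1e-6)  # rows differ only by a constant shift
```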
@@ -363,10 +367,17 @@ class PagedAttentionWithALiBi(PagedAttention):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value,
+                                            self.num_queries_per_kv,
+                                            dim=1)
+
         # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
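The new `repeat_interleave` block exists because the prompt-phase attention path expects key/value tensors with the same number of heads as the query; when Falcon shares one KV head across a group of query heads, the KV heads are simply duplicated to line up with their query heads. A shape-level sketch with assumed Falcon-40B-like head counts:

```python
# Shape-level sketch of the KV-head expansion above (head counts are assumptions).
import torch

num_tokens, num_heads, num_kv_heads, head_size = 16, 128, 8, 64
num_queries_per_kv = num_heads // num_kv_heads  # 16 query heads share one KV head

key = torch.randn(num_tokens, num_kv_heads, head_size)
value = torch.randn(num_tokens, num_kv_heads, head_size)

# Duplicate each KV head so the tensors match the query head count.
key = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value, num_queries_per_kv, dim=1)

assert key.shape == value.shape == (num_tokens, num_heads, head_size)
# Query heads 0..15 now see copies of KV head 0, heads 16..31 see KV head 1, and so on.
```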
@@ -400,9 +411,10 @@ class PagedAttentionWithALiBi(PagedAttention):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
......
@@ -14,6 +14,7 @@ _MODEL_REGISTRY = {
     "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
     "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
+    "FalconForCausalLM": FalconForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTJForCausalLM": GPTJForCausalLM,
@@ -22,6 +23,7 @@ _MODEL_REGISTRY = {
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
     "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
+    "RWForCausalLM": FalconForCausalLM,
 }
......
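Registering both the `FalconForCausalLM` and the legacy `RWForCausalLM` architecture strings against the same class lets older tiiuae checkpoints, whose `config.json` still lists the RefinedWeb architecture, resolve to the new Falcon implementation. A minimal sketch of this lookup pattern with stand-in classes (the real registry maps to vLLM's model classes):

```python
# Minimal sketch of the architecture-string lookup (stand-in class, not vLLM's loader).
class FalconForCausalLM: ...

_MODEL_REGISTRY = {
    "FalconForCausalLM": FalconForCausalLM,  # new-style configs
    "RWForCausalLM": FalconForCausalLM,      # legacy RefinedWeb configs
}

def resolve_model_class(architectures):
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Unsupported architectures: {architectures}")

assert resolve_model_class(["RWForCausalLM"]) is FalconForCausalLM
```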
 from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
+from vllm.model_executor.models.falcon import FalconForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_j import GPTJForCausalLM
@@ -12,6 +13,7 @@ __all__ = [
     "BaiChuanForCausalLM",
     "BaichuanForCausalLM",
     "BloomForCausalLM",
+    "FalconForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
     "GPTJForCausalLM",
......
This diff is collapsed.
@@ -44,7 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
 # rank when broadcasting weights from src to all other data parallel ranks
 _DATA_PARALLEL_GLOBAL_RANKS = None

-_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None

 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
@@ -196,20 +195,6 @@ def initialize_model_parallel(
         if rank in ranks:
             _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

-def initialize_all_reduce_launcher(
-    max_num_tokens: int,
-    hidden_size: int,
-    dtype: torch.dtype,
-    disable_graph: bool = False,
-) -> None:
-    global _ALL_REDUCE_LAUNCHER
-    _ALL_REDUCE_LAUNCHER = GraphAllReduce(
-        max_num_tokens=max_num_tokens,
-        hidden_size=hidden_size,
-        dtype=dtype,
-        disable_graph=disable_graph,
-    )
-
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
     if _TENSOR_MODEL_PARALLEL_GROUP is None or \
@@ -458,6 +443,7 @@ def get_pipeline_model_parallel_last_rank():
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]

+
 def get_pipeline_model_parallel_next_rank():
     """Return the global rank that follows the caller in the pipeline"""
     assert _PIPELINE_GLOBAL_RANKS is not None, \
@@ -485,10 +471,6 @@ def get_data_parallel_rank():
     """Return my rank for the data parallel group."""
     return torch.distributed.get_rank(group=get_data_parallel_group())

-def get_all_reduce_launcher() -> 'GraphAllReduce':
-    assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
-    return _ALL_REDUCE_LAUNCHER
-
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP
@@ -515,56 +497,3 @@ def destroy_model_parallel():
     _MPU_TENSOR_MODEL_PARALLEL_RANK = None
     global _MPU_PIPELINE_MODEL_PARALLEL_RANK
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
-
-class GraphAllReduce:
-
-    def __init__(
-        self,
-        max_num_tokens: int,
-        hidden_size: int,
-        dtype: torch.dtype,
-        disable_graph: bool = False,
-    ) -> None:
-        self.max_num_tokens = max_num_tokens
-        self.hidden_size = hidden_size
-        self.disable_graph = disable_graph
-
-        tp_world_size = get_tensor_model_parallel_world_size()
-        if tp_world_size == 1:
-            return
-        self.group = get_tensor_model_parallel_group()
-        self.buffer = torch.empty(
-            size=(max_num_tokens, hidden_size),
-            dtype=dtype,
-            device='cuda',
-        )
-
-        # Build graphs for different number of tokens.
-        if not self.disable_graph:
-            self.graphs = {}
-            for num_tokens in range(8, max_num_tokens + 1, 8):
-                self.graphs[num_tokens] = self._build_graph(num_tokens)
-
-    def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
-        # Warm up.
-        torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
-        torch.cuda.synchronize()
-
-        # Build graph.
-        graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(graph):
-            torch.distributed.all_reduce(
-                self.buffer[:num_tokens], group=self.group)
-        torch.cuda.synchronize()
-        return graph
-
-    def launch(self, x: torch.Tensor) -> torch.Tensor:
-        # NOTE: x must be a slice of self.buffer.
-        num_tokens = x.shape[0]
-        if self.disable_graph:
-            torch.distributed.all_reduce(x, group=self.group)
-        else:
-            self.graphs[num_tokens].replay()
-        return x
@@ -12,6 +12,7 @@ from .mappings import (
     copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     gather_from_sequence_parallel_region,
+    reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
     scatter_to_sequence_parallel_region,
 )
@@ -38,7 +39,7 @@ __all__ = [
     "copy_to_tensor_model_parallel_region",
     "gather_from_tensor_model_parallel_region",
     "gather_from_sequence_parallel_region",
-    # "reduce_from_tensor_model_parallel_region",
+    "reduce_from_tensor_model_parallel_region",
     "scatter_to_tensor_model_parallel_region",
     "scatter_to_sequence_parallel_region",
     # random.py
......
@@ -14,7 +14,6 @@ from torch.nn.parameter import Parameter
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_all_reduce_launcher,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,
@@ -248,8 +247,8 @@ class ColumnParallelLinear(torch.nn.Module):
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, self.world_size)
         self.skip_bias_add = skip_bias_add

         if params_dtype is None:
@@ -350,6 +349,7 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype:
         use_cpu_initialization:
         perform_initialization:
+        reduce_results:
     """

     def __init__(self, input_size, output_size, *,
@@ -360,6 +360,7 @@ class RowParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
+                 reduce_results=True,
                  ):
         super(RowParallelLinear, self).__init__()
@@ -367,14 +368,19 @@ class RowParallelLinear(torch.nn.Module):
         self.input_size = input_size
         self.output_size = output_size
         self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.world_size)
         self.skip_bias_add = skip_bias_add

+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reduce the results, adding bias to the "
+                             "results can lead to incorrect results")
+
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
@@ -427,17 +433,12 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        if get_tensor_model_parallel_world_size() == 1:
-            # Matrix multiply.
-            output_ = F.linear(input_parallel, self.weight)
-        else:
-            # Matrix multiply.
-            all_reduce_launcher = get_all_reduce_launcher()
-            num_tokens = input_parallel.shape[0]
-            output_buffer = all_reduce_launcher.buffer[:num_tokens]
-            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
-            # All-reduce across all the partitions.
-            output_ = all_reduce_launcher.launch(output_buffer)
+        # Matrix multiply.
+        output_parallel = F.linear(input_parallel, self.weight)
+        if self.reduce_results and self.world_size > 1:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        else:
+            output_ = output_parallel

         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
......
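With the CUDA-graph all-reduce launcher gone, the row-parallel forward pass reduces to a plain matmul followed by an optional all-reduce. `reduce_results=False` lets a caller skip the per-layer all-reduce and reduce once later (useful, for example, when attention and MLP branches run in parallel and their outputs are summed before a single reduction), which is also why pairing it with an immediately-added bias is rejected. A condensed, self-contained sketch of that control flow (helper and variable names are assumptions, and the all-reduce is stubbed out):

```python
# Condensed sketch of the row-parallel forward path after this change
# (assumed names; the real code lives in vLLM's tensor-parallel layers).
import torch
import torch.nn.functional as F

def all_reduce_sum(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for reduce_from_tensor_model_parallel_region; a real run would call
    # torch.distributed.all_reduce over the tensor-parallel group.
    return x

def row_parallel_forward(weight, bias, input_parallel, *, world_size=1,
                         reduce_results=True, skip_bias_add=False):
    output_parallel = F.linear(input_parallel, weight)   # per-rank partial matmul
    if reduce_results and world_size > 1:
        output_ = all_reduce_sum(output_parallel)         # sum partials across ranks
    else:
        output_ = output_parallel                          # caller reduces later
    if not skip_bias_add:
        return output_ + bias if bias is not None else output_
    return output_                                         # bias handed back separately upstream

x = torch.randn(4, 32)    # input shard on this rank
w = torch.randn(16, 32)   # weight shard: [output_size, input_size_per_partition]
y = row_parallel_forward(w, None, x)
assert y.shape == (4, 16)
```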
@@ -5,6 +5,8 @@ from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
 _CONFIG_REGISTRY = {
     "mpt": MPTConfig,
     "baichuan": BaiChuanConfig,
+    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
+    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
 }
......
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
+# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
+# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
+# `FalconConfig` class from the official HuggingFace transformers library.
+from vllm.transformers_utils.configs.falcon import RWConfig

 __all__ = [
     "MPTConfig",
     "BaiChuanConfig",
+    "RWConfig",
 ]
# Adapted from
# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py
# Copyright 2023 The vLLM team.
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Falcon configuration"""
from transformers.configuration_utils import PretrainedConfig
class RWConfig(PretrainedConfig):
    model_type = "falcon"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
        "num_kv_heads": "n_head_kv",
    }

    def __init__(
        self,
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        multi_query=True,
        n_head_kv=None,
        alibi=False,
        bias=False,
        parallel_attn=False,
        new_decoder_architecture=False,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.multi_query = multi_query
        self.n_head_kv = 1 if n_head_kv is None else n_head_kv
        self.alibi = alibi
        self.bias = bias
        self.parallel_attn = parallel_attn
        self.new_decoder_architecture = new_decoder_architecture

        if self.hidden_size == 8192:
            # Hack for falcon-40b
            self.new_decoder_architecture = True

        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

    @property
    def head_dim(self):
        return self.hidden_size // self.n_head

    @property
    def rotary(self):
        return not self.alibi
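Taken together with the `_CONFIG_REGISTRY` entries above, this shim lets the old RefinedWeb checkpoints present themselves as `model_type == "falcon"` with normalized attribute names. A brief usage sketch, with values chosen to resemble tiiuae/falcon-40b (the real values come from the checkpoint's config.json):

```python
# Usage sketch of the RWConfig shim defined above (illustrative values).
cfg = RWConfig(hidden_size=8192, n_layer=60, n_head=128, n_head_kv=8,
               parallel_attn=True, bias=False)
assert cfg.model_type == "falcon"
assert cfg.new_decoder_architecture   # forced on by the hidden_size == 8192 hack
assert cfg.head_dim == 64             # property derived from hidden_size // n_head
assert cfg.rotary and not cfg.alibi   # rotary embeddings, no ALiBi
```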
@@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import get_model, InputMetadata, set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
-    initialize_model_parallel, initialize_all_reduce_launcher)
+    initialize_model_parallel)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
 from vllm.worker.cache_engine import CacheEngine
@@ -65,11 +65,6 @@ class Worker:
         # Initialize the model.
         set_random_seed(self.model_config.seed)
         self.model = get_model(self.model_config)
-        initialize_all_reduce_launcher(
-            self.scheduler_config.max_num_batched_tokens,
-            self.model_config.get_hidden_size(),
-            self.model_config.dtype,
-        )

     @torch.inference_mode()
     def profile_num_available_blocks(
......