Unverified commit e41f0670, authored by Woosuk Kwon and committed by GitHub

Add support for BLOOM (#331)

parent d6fa1be3
@@ -41,6 +41,7 @@ vLLM is flexible and easy to use with:
vLLM seamlessly supports many Huggingface models, including the following architectures:
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
......
#include <torch/extension.h>
#include <c10/util/Optional.h>
void single_query_cached_kv_attention(
torch::Tensor& out,
@@ -9,7 +10,8 @@ void single_query_cached_kv_attention(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
......
@@ -80,6 +80,7 @@ __global__ void single_query_cached_kv_attention_kernel(
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride) {
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE;
@@ -91,6 +92,7 @@ __global__ void single_query_cached_kv_attention_kernel(
const int head_idx = blockIdx.x;
const int num_heads = gridDim.x;
const int seq_idx = blockIdx.y;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
@@ -167,12 +169,14 @@ __global__ void single_query_cached_kv_attention_kernel(
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs, k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0;
if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -328,6 +332,7 @@ __global__ void single_query_cached_kv_attention_kernel(
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
query_stride);
// TODO(woosuk): Tune NUM_THREADS.
@@ -343,7 +348,8 @@ void single_query_cached_kv_attention_launcher(
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -353,6 +359,11 @@ void single_query_cached_kv_attention_launcher(
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
assert(head_size % thread_group_size == 0);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
@@ -411,7 +422,8 @@ void single_query_cached_kv_attention_launcher(
scale, \
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -458,7 +470,8 @@ void single_query_cached_kv_attention(
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
if (query.dtype() == at::ScalarType::Float) {
CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float);
} else if (query.dtype() == at::ScalarType::Half) {
......
@@ -14,6 +14,9 @@ Alongside each architecture, we include some popular models that use it.
* - Architecture
- Models
- Example HuggingFace Models
* - :code:`BloomForCausalLM`
- BLOOM, BLOOMZ, BLOOMChat
- :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
* - :code:`GPT2LMHeadModel`
- GPT-2
- :code:`gpt2`, :code:`gpt2-xl`, etc.
......
@@ -216,6 +216,7 @@ def run_single_query_cached_kv_attention(
context_lens,
block_size,
max_context_len,
None,  # ALiBi slopes.
)
ref_output = torch.empty_like(query)
......
from typing import Dict, List, Tuple
import torch
from xformers.ops import AttentionBias
from vllm.sampling_params import SamplingParams
from vllm.sequence import SequenceData
@@ -38,7 +38,6 @@ class InputMetadata:
self.max_context_len = max_context_len
self.block_tables = block_tables
self.num_prompts = len(prompt_lens)
self.num_prompt_tokens = sum(prompt_lens)
self.num_generation_tokens = context_lens.shape[0]
@@ -50,6 +49,9 @@ class InputMetadata:
assert block_tables.shape[0] == self.num_generation_tokens
assert context_lens.shape[0] == self.num_generation_tokens
# Set during the execution of the first attention op.
self.attn_bias: List[AttentionBias] = []
def __repr__(self) -> str:
# Print only useful metadata.
return (f'InputMetadata('
......
"""Multi-head attention.""" """Multi-head attention."""
from typing import Optional from typing import List, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from xformers import ops as xops from xformers import ops as xops
from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
LowerTriangularMaskWithTensorBias)
from vllm import attention_ops from vllm import attention_ops
from vllm import cache_ops from vllm import cache_ops
@@ -53,13 +55,21 @@ class PagedAttention(nn.Module):
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
return
prompt_lens = input_metadata.prompt_lens
attn_bias = BlockDiagonalCausalMask.from_seqlens(prompt_lens)
input_metadata.attn_bias.append(attn_bias)
def multi_query_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Normal attention for the prompt tokens. """Normal attention for the prompt tokens.
...@@ -68,13 +78,14 @@ class PagedAttention(nn.Module): ...@@ -68,13 +78,14 @@ class PagedAttention(nn.Module):
query: shape = [num_prompt_tokens, num_heads, head_size] query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size] key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size] value: shape = [num_prompt_tokens, num_heads, head_size]
input_metadata: metadata for paged attention.
""" """
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize. # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
out = xops.memory_efficient_attention_forward( out = xops.memory_efficient_attention_forward(
query.unsqueeze(0), query.unsqueeze(0),
key.unsqueeze(0), key.unsqueeze(0),
value.unsqueeze(0), value.unsqueeze(0),
attn_bias=attn_bias, attn_bias=input_metadata.attn_bias[0],
p=0.0, p=0.0,
scale=self.scale, scale=self.scale,
op=self.attn_op, op=self.attn_op,
@@ -112,6 +123,7 @@ class PagedAttention(nn.Module):
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
None,  # alibi_slopes
)
def forward(
@@ -154,12 +166,13 @@ class PagedAttention(nn.Module):
# Compute the attention op for prompts.
num_prompt_tokens = input_metadata.num_prompt_tokens
if num_prompt_tokens > 0:
self.set_attn_bias(input_metadata)
self.multi_query_kv_attention(
output[:num_prompt_tokens],
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata,
)
# Wait until the cache op is done.
@@ -219,7 +232,8 @@ class PagedAttentionWithRoPE(PagedAttention):
cache = torch.cat((cos, sin), dim=-1)
# FIXME(woosuk): This assumes that we configure the default dtype when
# initializing the model.
# TODO(woosuk): Make it more robust.
torch_dtype = torch.get_default_dtype()
cache = cache.to(torch_dtype)
# Embedding size: [max_position, rotary_dim]
@@ -271,3 +285,112 @@ class PagedAttentionWithRoPE(PagedAttention):
input_metadata,
cache_event,
)
class PagedAttentionWithALiBi(PagedAttention):
"""PagedAttention with ALiBi attention bias."""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
slopes: List[float],
) -> None:
super().__init__(num_heads, head_size, scale)
assert len(slopes) == num_heads
slopes = torch.tensor(slopes, dtype=torch.float32)
self.register_buffer("alibi_slopes", slopes, persistent=False)
def set_attn_bias(self, input_metadata: InputMetadata) -> None:
if input_metadata.attn_bias:
# Already set by a previous layer.
return
# Generates ALiBi mask for each prompt.
for prompt_len in input_metadata.prompt_lens:
bias = torch.arange(prompt_len)
bias = bias[None, :] - bias[:, None]
bias = bias.to(self.alibi_slopes.device)
# When using custom attention bias, xformers requires the bias to
# be sliced from a tensor whose length is a multiple of 8.
padded_len = (prompt_len + 7) // 8 * 8
bias = torch.empty(
self.num_heads,
padded_len,
padded_len,
device=self.alibi_slopes.device,
)[:, :prompt_len, :prompt_len].copy_(bias)
bias.mul_(self.alibi_slopes[:, None, None])
attn_bias = LowerTriangularMaskWithTensorBias(bias)
input_metadata.attn_bias.append(attn_bias)
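As a rough illustration of the bias built above (a sketch with made-up sizes, not part of the commit): for a prompt of length 3 and two heads with slopes 0.5 and 0.25, the tensor bias before the causal mask is slope * (j - i), allocated with its last dimensions padded to a multiple of 8 as xformers expects.
import torch
prompt_len = 3
slopes = torch.tensor([0.5, 0.25])            # example slopes for two heads
dist = torch.arange(prompt_len)
dist = dist[None, :] - dist[:, None]          # dist[i, j] = j - i
padded_len = (prompt_len + 7) // 8 * 8        # 8
bias = torch.empty(2, padded_len, padded_len)[:, :prompt_len, :prompt_len]
bias.copy_(dist)                              # broadcast over the head dimension
bias.mul_(slopes[:, None, None])
# bias[0] == [[ 0.0,  0.5,  1.0],
#             [-0.5,  0.0,  0.5],
#             [-1.0, -0.5,  0.0]]; LowerTriangularMaskWithTensorBias then masks j > i.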
def multi_query_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Attention with ALiBi bias for the prompt tokens.
Args:
output: shape = [num_prompt_tokens, num_heads, head_size]
query: shape = [num_prompt_tokens, num_heads, head_size]
key: shape = [num_prompt_tokens, num_heads, head_size]
value: shape = [num_prompt_tokens, num_heads, head_size]
input_metadata: metadata for paged attention.
"""
# FIXME(woosuk): Because xformers does not support dynamic sequence
# lengths with custom attention bias, we process each prompt one by
# one. This is inefficient, especially when we have many short prompts.
start = 0
for i, prompt_len in enumerate(input_metadata.prompt_lens):
end = start + prompt_len
out = xops.memory_efficient_attention_forward(
query[None, start:end],
key[None, start:end],
value[None, start:end],
attn_bias=input_metadata.attn_bias[i],
p=0.0,
scale=self.scale,
op=self.attn_op,
)
# TODO(woosuk): Unnecessary copy. Optimize.
output[start:end].copy_(out.squeeze(0))
start += prompt_len
return output
def single_query_cached_kv_attention(
self,
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
input_metadata: InputMetadata,
) -> None:
"""PagedAttention with ALiBi bias for the generation tokens.
Args:
output: shape = [num_generation_tokens, num_heads, head_size]
query: shape = [num_generation_tokens, num_heads, head_size]
key_cache: shape = [num_blocks, num_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_heads, head_size, block_size]
input_metadata: metadata for paged attention.
"""
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
self.alibi_slopes,
)
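For reference, the bias the CUDA kernel adds at decode time is a per-head linear penalty on key positions. A minimal PyTorch sketch of the score computation for one head and one query (a hypothetical helper, not the paged kernel itself):
import torch
def alibi_decode_scores(q, keys, scale, slope, context_len):
    # q: [head_size], keys: [context_len, head_size] for a single head.
    scores = scale * (keys @ q)
    positions = torch.arange(context_len)
    # Mirrors the kernel: slope * (token_idx - context_len). The constant -1
    # offset versus slope * (token_idx - (context_len - 1)) cancels in softmax.
    scores = scores + slope * (positions - context_len)
    return torch.softmax(scores, dim=-1)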
@@ -6,13 +6,12 @@ import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.model_executor.models import * # pylint: disable=wildcard-import
from vllm.model_executor.weight_utils import initialize_dummy_weights
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
"BloomForCausalLM": BloomForCausalLM,
"GPT2LMHeadModel": GPT2LMHeadModel,
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
......
from vllm.model_executor.models.bloom import BloomForCausalLM
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
@@ -5,6 +6,7 @@ from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
__all__ = [
"BloomForCausalLM",
"GPT2LMHeadModel",
"GPTBigCodeForCausalLM",
"GPTNeoXForCausalLM",
......
# coding=utf-8
# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The CacheFlow team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights.
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import Dict, List, Optional, Tuple
import torch
from torch import nn
from transformers import BloomConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
KVCache = Tuple[torch.Tensor, torch.Tensor]
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
base = torch.tensor(
2**(-(2**-(math.log2(closest_power_of_2) - 3))),
dtype=torch.float32,
)
powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
slopes = torch.pow(base, powers)
if closest_power_of_2 != total_num_heads:
extra_base = torch.tensor(
2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
dtype=torch.float32,
)
num_remaining_heads = min(closest_power_of_2,
total_num_heads - closest_power_of_2)
extra_powers = torch.arange(start=1,
end=1 + 2 * num_remaining_heads,
step=2,
dtype=torch.int32)
slopes = torch.cat(
[slopes, torch.pow(extra_base, extra_powers)], dim=0)
return slopes
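This reproduces the ALiBi slope schedule: for a power-of-two head count n, the slopes are the geometric sequence 2^(-8/n), 2^(-16/n), ..., and when n is not a power of two the remaining heads take every other slope from the schedule of the next power of two. A quick sanity check (a sketch assuming this function is importable as written):
slopes = _get_alibi_slopes(8)
# tensor([0.5000, 0.2500, 0.1250, 0.0625, 0.0312, 0.0156, 0.0078, 0.0039])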
class BloomAttention(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.total_num_heads = config.n_head
self.head_dim = self.hidden_size // self.total_num_heads
assert self.head_dim * self.total_num_heads == self.hidden_size
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
self.query_key_value = ColumnParallelLinear(
self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
output, _ = self.dense(attn_output)
return output
class BloomMLP(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
gather_output=False,
perform_initialization=False)
self.act = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.dense_h_to_4h(x)
x = self.act(x)
x, _ = self.dense_4h_to_h(x)
return x
class BloomBlock(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Layer norm post the self attention.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
# Self attention.
attention_output = self.self_attention(
position_ids=position_ids,
hidden_states=layernorm_output,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output)
# Get residual
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = attention_output
# MLP.
output = self.mlp(layernorm_output) + residual
return output
class BloomModel(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
self.embed_dim = config.hidden_size
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
self.word_embeddings_layernorm = nn.LayerNorm(
self.embed_dim, eps=config.layer_norm_epsilon)
# Transformer blocks
self.h = nn.ModuleList(
[BloomBlock(config) for _ in range(config.num_hidden_layers)])
# Final Layer Norm
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
for i in range(len(self.h)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class BloomForCausalLM(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
self.config = config
self.transformer = BloomModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.word_embeddings.weight
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens
_column_parallel_weights = [
"word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"
]
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if not name.startswith("transformer."):
name = "transformer." + name
param = state_dict[name]
if "query_key_value" in name:
# NOTE(woosuk): BLOOM's fused QKV has the shape of
# [num_heads * 3 * head_size, hidden_size], while the
# required shape is [3 * num_heads * head_size, hidden_size].
# Thus, we need weight conversion.
shard_size = param.shape[0]
start = shard_size * tp_rank
end = shard_size * (tp_rank + 1)
loaded_weight = loaded_weight[start:end]
num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // num_heads
if "query_key_value.weight" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size,
hidden_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif "query_key_value.bias" in name:
loaded_weight = loaded_weight.view(-1, 3, head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected weight name: {name}")
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
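The QKV reordering in load_weights can be checked with tiny dummy shapes. A hedged sketch (hypothetical sizes, illustration only): HuggingFace BLOOM stores the fused weight with q, k and v grouped per head, while the attention layer expects all q rows, then all k rows, then all v rows.
import torch
num_heads, head_size, hidden_size = 2, 3, 6
w = torch.arange(num_heads * 3 * head_size * hidden_size, dtype=torch.float32)
w = w.view(num_heads * 3 * head_size, hidden_size)       # HF layout: [h0 q,k,v | h1 q,k,v]
w = w.view(num_heads, 3, head_size, hidden_size)
w = w.transpose(0, 1).reshape(-1, hidden_size)            # [all q | all k | all v]
# Rows 0-5 now hold q for heads 0 and 1, rows 6-11 hold k, rows 12-17 hold v.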
@@ -80,7 +80,6 @@ class GPTNeoXAttention(nn.Module):
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
......