Unverified Commit 37ca5581 authored by Woosuk Kwon, committed by GitHub

Optimize model execution with CUDA graph (#1926)


Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent eed74a55
......@@ -23,6 +23,7 @@ def main(args: argparse.Namespace):
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
)
sampling_params = SamplingParams(
......@@ -111,6 +112,9 @@ if __name__ == '__main__':
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--enforce-eager',
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
'--profile',
action='store_true',
......
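
The new --enforce-eager flag is a plain argparse boolean; a minimal standalone sketch (not part of the diff) of how it behaves:

    import argparse

    # Absent -> False (CUDA graphs stay enabled); present -> True (always eager).
    parser = argparse.ArgumentParser()
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    print(parser.parse_args([]).enforce_eager)                   # False
    print(parser.parse_args(['--enforce-eager']).enforce_eager)  # True
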
......@@ -69,7 +69,8 @@ def run_vllm(
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int] = None,
max_model_len: Optional[int],
enforce_eager: bool,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
......@@ -81,6 +82,7 @@ def run_vllm(
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
)
# Add the requests to the engine.
......@@ -204,7 +206,7 @@ def main(args: argparse.Namespace):
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len)
args.max_model_len, args.enforce_eager)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
......@@ -279,6 +281,9 @@ if __name__ == "__main__":
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
......
......@@ -12,3 +12,4 @@ fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
aioprometheus[starlette]
cupy-cuda12x # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead. # FIXME: Fix this in setup.py.
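
The CuPy wheel has to match the local CUDA toolkit (cupy-cuda12x for CUDA 12, cupy-cuda11x for CUDA 11.8, as the comment notes). A small sanity-check sketch, assuming only that CuPy is importable:

    import cupy

    # Confirm the installed wheel and the CUDA runtime it was built against.
    print(cupy.__version__)
    print(cupy.cuda.runtime.runtimeGetVersion())  # e.g. 12020 for CUDA 12.2
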
......@@ -49,6 +49,12 @@ class ModelConfig:
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid mode.
max_context_len_to_capture: Maximum context length covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
"""
def __init__(
......@@ -65,6 +71,8 @@ class ModelConfig:
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
......@@ -76,6 +84,8 @@ class ModelConfig:
self.revision = revision
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_context_len_to_capture = max_context_len_to_capture
if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
# download model from ModelScope hub,
......@@ -95,6 +105,7 @@ class ModelConfig:
self._verify_load_format()
self._verify_tokenizer_mode()
self._verify_quantization()
self._verify_cuda_graph()
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
......@@ -169,6 +180,12 @@ class ModelConfig:
"optimized yet. The speed can be slower than "
"non-quantized models.")
def _verify_cuda_graph(self) -> None:
if self.max_context_len_to_capture is None:
self.max_context_len_to_capture = self.max_model_len
self.max_context_len_to_capture = min(self.max_context_len_to_capture,
self.max_model_len)
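
The new _verify_cuda_graph simply defaults the capture limit to the model length and clamps it there; a standalone mirror of that logic (not part of the diff), with illustrative numbers:

    # Standalone mirror of _verify_cuda_graph (illustrative values only).
    def clamp_capture_len(max_context_len_to_capture, max_model_len):
        if max_context_len_to_capture is None:
            max_context_len_to_capture = max_model_len
        return min(max_context_len_to_capture, max_model_len)

    print(clamp_capture_len(None, 2048))  # 2048: defaults to the model length
    print(clamp_capture_len(8192, 2048))  # 2048: clamped to the model length
    print(clamp_capture_len(1024, 2048))  # 1024: user value kept
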
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
......
......@@ -33,6 +33,8 @@ class EngineArgs:
revision: Optional[str] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: bool = False
max_context_len_to_capture: int = 8192
def __post_init__(self):
if self.tokenizer is None:
......@@ -182,6 +184,17 @@ class EngineArgs:
choices=['awq', 'gptq', 'squeezellm', None],
default=None,
help='Method used to quantize the weights')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
'will use eager mode and CUDA graph in hybrid mode '
'for maximal performance and flexibility.')
parser.add_argument('--max-context-len-to-capture',
type=int,
default=EngineArgs.max_context_len_to_capture,
help='maximum context length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode.')
return parser
@classmethod
......@@ -200,7 +213,8 @@ class EngineArgs:
self.download_dir, self.load_format,
self.dtype, self.seed, self.revision,
self.tokenizer_revision, self.max_model_len,
self.quantization)
self.quantization, self.enforce_eager,
self.max_context_len_to_capture)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space,
......
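
A hedged sketch of driving the two new engine flags through the CLI helpers (assumes the existing EngineArgs.add_cli_args / from_cli_args methods used by the servers; the model name is just a placeholder):

    import argparse
    from vllm.engine.arg_utils import EngineArgs

    # Sketch only: wire the new flags through the existing CLI helpers.
    parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
    args = parser.parse_args([
        '--model', 'facebook/opt-125m',          # placeholder model
        '--max-context-len-to-capture', '1024',  # graph-capture cap; longer contexts run eagerly
    ])
    engine_args = EngineArgs.from_cli_args(args)
    print(engine_args.enforce_eager, engine_args.max_context_len_to_capture)
    # -> False 1024; pass --enforce-eager instead to disable CUDA graphs entirely.
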
......@@ -17,7 +17,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
from vllm.utils import Counter, get_open_port
if ray:
from ray.air.util.torch_dist import init_torch_dist_process_group
......@@ -84,6 +84,7 @@ class LLMEngine:
f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.
......@@ -189,6 +190,7 @@ class LLMEngine:
))
self._run_workers(
"init_model",
cupy_port=get_open_port(),
get_all_outputs=True,
)
self._run_workers(
......@@ -232,6 +234,9 @@ class LLMEngine:
# Initialize the cache.
self._run_workers("init_cache_engine", cache_config=self.cache_config)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self._run_workers("warm_up_model")
@classmethod
def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
......
import socket
from typing import Optional, Tuple, TYPE_CHECKING
from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip
from vllm.utils import get_open_port, is_hip
logger = init_logger(__name__)
......@@ -43,12 +42,6 @@ if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup
def get_open_port():
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
def initialize_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
......
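
get_open_port now lives in vllm.utils (see the updated imports above) and is what the engine hands to workers as cupy_port; a minimal usage sketch:

    from vllm.utils import get_open_port

    # Ask the OS for a free TCP port; the engine passes such a port to the
    # workers as cupy_port in init_model (see the llm_engine diff above).
    print(get_open_port())
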
......@@ -56,6 +56,12 @@ class LLM:
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid mode.
max_context_len_to_capture: Maximum context length covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
"""
def __init__(
......@@ -72,6 +78,8 @@ class LLM:
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
......@@ -89,6 +97,8 @@ class LLM:
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
......
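
A hedged end-to-end sketch of the two new LLM arguments (placeholder model and illustrative values; CUDA graphs stay enabled by default):

    from vllm import LLM, SamplingParams

    # Default: CUDA graphs on; cap graph capture at 1024 context tokens
    # (sequences longer than that fall back to eager mode).
    llm = LLM(model="facebook/opt-125m", max_context_len_to_capture=1024)

    # Or opt out of CUDA graphs entirely:
    # llm = LLM(model="facebook/opt-125m", enforce_eager=True)

    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)
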
......@@ -21,12 +21,14 @@ class InputMetadata:
max_context_len: Optional[int],
context_lens: Optional[torch.Tensor],
block_tables: Optional[torch.Tensor],
use_cuda_graph: bool,
) -> None:
self.prompt_lens = prompt_lens
self.max_context_len = max_context_len
self.slot_mapping = slot_mapping
self.context_lens = context_lens
self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph
self.is_prompt = len(prompt_lens) > 0
# Set during the execution of the first attention op.
......@@ -39,4 +41,5 @@ class InputMetadata:
f"max_context_len={self.max_context_len}, "
f"slot_mapping={self.slot_mapping}, "
f"context_lens={self.context_lens}, "
f"block_tables={self.block_tables})")
f"block_tables={self.block_tables}, "
f"use_cuda_graph={self.use_cuda_graph})")
......@@ -24,13 +24,10 @@ class PagedAttention(nn.Module):
can either contain prompt tokens or generation tokens.
The class does the following:
1. Wait for the cache operations (e.g., swap, copy) to finish. The cache
operations are issued by the cache engine before executing the forward
pass of the model, and they are executed asynchronously.
2. Reshape and store the input key and value tensors in the KV cache.
3. Perform (multi-head/multi-query/grouped-query) attention using either
1. Reshape and store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention using either
xformers or the PagedAttention custom op.
4. Return the output tensor.
3. Return the output tensor.
"""
def __init__(
......@@ -67,7 +64,6 @@ class PagedAttention(nn.Module):
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
"""PagedAttention forward pass.
......@@ -80,7 +76,6 @@ class PagedAttention(nn.Module):
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
cache_event: event to wait for the cache operations to finish.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
......@@ -89,10 +84,6 @@ class PagedAttention(nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
slot_mapping = input_metadata.slot_mapping.flatten()
if cache_event is not None:
cache_event.wait()
# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
......@@ -104,7 +95,7 @@ class PagedAttention(nn.Module):
value,
key_cache,
value_cache,
slot_mapping,
input_metadata.slot_mapping.flatten(),
)
if input_metadata.is_prompt:
......@@ -165,15 +156,20 @@ class PagedAttention(nn.Module):
output = out.view_as(query)
else:
# Decoding run.
output = _paged_attention(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
if key_cache is not None and value_cache is not None:
output = _paged_attention(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
else:
# This happens during the initial memory profiling run for
# CUDA graphs.
output = torch.zeros_like(query)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
......
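
During the initial memory-profiling run for CUDA graphs the KV caches are still None, so the decode path now short-circuits to zeros instead of calling the kernel; the branch in isolation (not part of the diff), with a placeholder standing in for the paged-attention op:

    import torch

    # Mirror of the new branch: skip the kernel and return zeros while the KV
    # caches are unallocated (profiling run); otherwise run the real op,
    # represented here by a placeholder callable.
    def decode_output(query, key_cache, value_cache, paged_attention_op):
        if key_cache is not None and value_cache is not None:
            return paged_attention_op(query, key_cache, value_cache)
        return torch.zeros_like(query)

    q = torch.randn(4, 8, 64)  # illustrative [num_tokens, num_heads, head_size]
    print(decode_output(q, None, None, lambda *a: None).shape)  # torch.Size([4, 8, 64])
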
......@@ -158,14 +158,12 @@ class AquilaAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
......@@ -209,7 +207,6 @@ class AquilaDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Self Attention
residual = hidden_states
......@@ -219,7 +216,6 @@ class AquilaDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
......@@ -258,18 +254,15 @@ class AquilaModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
......@@ -296,10 +289,9 @@ class AquilaForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(
......
......@@ -172,15 +172,13 @@ class BaiChuanAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.W_pack(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.postion_embedding != "ALIBI":
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
......@@ -221,7 +219,6 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -236,7 +233,6 @@ class BaiChuanDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
......@@ -273,19 +269,16 @@ class BaiChuanModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -311,10 +304,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(
......
......@@ -118,14 +118,12 @@ class BloomAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
del position_ids # Unused.
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.dense(attn_output)
return output
......@@ -184,7 +182,6 @@ class BloomBlock(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
......@@ -201,7 +198,6 @@ class BloomBlock(nn.Module):
hidden_states=layernorm_output,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
attention_output = attention_output + residual
layernorm_output = self.post_attention_layernorm(attention_output)
......@@ -250,19 +246,16 @@ class BloomModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(hidden_states)
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -288,10 +281,9 @@ class BloomForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(
......
......@@ -100,7 +100,6 @@ class GLMAttention(nn.Module):
position_ids: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
......@@ -113,7 +112,6 @@ class GLMAttention(nn.Module):
key_cache,
value_cache,
input_metadata,
cache_event,
)
attn_output, _ = self.dense(context_layer)
return attn_output
......@@ -203,7 +201,6 @@ class GLMBlock(nn.Module):
position_ids: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
......@@ -214,7 +211,6 @@ class GLMBlock(nn.Module):
position_ids=position_ids,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Residual connection.
......@@ -269,17 +265,14 @@ class GLMTransformer(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
for i in range(self.num_layers):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i],
input_metadata=input_metadata,
cache_event=cache_event,
)
# Final layer norm.
if self.post_layer_norm:
......@@ -314,8 +307,7 @@ class ChatGLMModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
):
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
# Run encoder.
......@@ -324,9 +316,7 @@ class ChatGLMModel(nn.Module):
position_ids=position_ids,
kv_caches=kv_caches,
input_metadata=input_metadata,
cache_events=cache_events,
)
return hidden_states
......@@ -350,10 +340,9 @@ class ChatGLMForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(
......
......@@ -178,7 +178,6 @@ class FalconAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, bias = self.query_key_value(hidden_states)
if bias is not None:
......@@ -187,8 +186,7 @@ class FalconAttention(nn.Module):
if self.use_rotary:
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output, bias = self.dense(attn_output)
return attn_output, bias
......@@ -266,8 +264,7 @@ class FalconDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
):
) -> torch.Tensor:
residual = hidden_states
if self.config.new_decoder_architecture:
......@@ -282,7 +279,6 @@ class FalconDecoderLayer(nn.Module):
hidden_states=attention_layernorm_out,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.reduce_row_parallel_results and attention_bias is not None:
attention_output += attention_bias
......@@ -311,7 +307,6 @@ class FalconDecoderLayer(nn.Module):
mlp_output += mlp_bias
output = mlp_output + residual
return output
......@@ -349,18 +344,15 @@ class FalconModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -389,14 +381,12 @@ class FalconForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
input_metadata,
cache_events,
)
return hidden_states
......
......@@ -82,13 +82,12 @@ class GPT2Attention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
input_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
......@@ -148,7 +147,6 @@ class GPT2Block(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
......@@ -156,7 +154,6 @@ class GPT2Block(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# residual connection
hidden_states = attn_output + residual
......@@ -196,17 +193,14 @@ class GPT2Model(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -232,10 +226,9 @@ class GPT2LMHeadModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(
......
......@@ -95,7 +95,6 @@ class GPTBigCodeAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
......@@ -107,7 +106,7 @@ class GPTBigCodeAttention(nn.Module):
)
key_cache, value_cache = kv_cache
attn_output = self.attn(q, k, v, key_cache, value_cache,
input_metadata, cache_event)
input_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
......@@ -167,7 +166,6 @@ class GPTBigCodeBlock(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
......@@ -175,7 +173,6 @@ class GPTBigCodeBlock(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# residual connection
hidden_states = attn_output + residual
......@@ -215,17 +212,14 @@ class GPTBigCodeModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
cache_event)
hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -251,10 +245,9 @@ class GPTBigCodeForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(
......
......@@ -94,14 +94,12 @@ class GPTJAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
attn_output, _ = self.out_proj(attn_output)
return attn_output
......@@ -156,7 +154,6 @@ class GPTJBlock(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
......@@ -165,7 +162,6 @@ class GPTJBlock(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
mlp_output = self.mlp(hidden_states)
hidden_states = attn_output + mlp_output + residual
......@@ -196,18 +192,15 @@ class GPTJModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -238,10 +231,9 @@ class GPTJForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(
......
......@@ -92,14 +92,12 @@ class GPTNeoXAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.dense(attn_output)
return output
......@@ -155,7 +153,6 @@ class GPTNeoXLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
attn_input = self.input_layernorm(hidden_states)
attn_output = self.attention(
......@@ -163,7 +160,6 @@ class GPTNeoXLayer(nn.Module):
hidden_states=attn_input,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
if self.use_parallel_residual:
......@@ -210,18 +206,15 @@ class GPTNeoXModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
......@@ -250,10 +243,9 @@ class GPTNeoXForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(
......
......@@ -110,14 +110,12 @@ class InternLMAttention(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
......@@ -160,7 +158,6 @@ class InternLMDecoderLayer(nn.Module):
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -175,7 +172,6 @@ class InternLMDecoderLayer(nn.Module):
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# Fully Connected
......@@ -214,19 +210,16 @@ class InternLMModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
cache_event = None if cache_events is None else cache_events[i]
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -253,10 +246,9 @@ class InternLMForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
input_metadata)
return hidden_states
def sample(
......