Unverified Commit 37ca5581 authored by Woosuk Kwon, committed by GitHub

Optimize model execution with CUDA graph (#1926)


Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent eed74a55
@@ -23,6 +23,7 @@ def main(args: argparse.Namespace):
         tensor_parallel_size=args.tensor_parallel_size,
         trust_remote_code=args.trust_remote_code,
         dtype=args.dtype,
+        enforce_eager=args.enforce_eager,
     )
     sampling_params = SamplingParams(
@@ -111,6 +112,9 @@ if __name__ == '__main__':
                         'The "auto" option will use FP16 precision '
                         'for FP32 and FP16 models, and BF16 precision '
                         'for BF16 models.')
+    parser.add_argument('--enforce-eager',
+                        action='store_true',
+                        help='enforce eager mode and disable CUDA graph')
     parser.add_argument(
         '--profile',
         action='store_true',
...
@@ -69,7 +69,8 @@ def run_vllm(
     use_beam_search: bool,
     trust_remote_code: bool,
     dtype: str,
-    max_model_len: Optional[int] = None,
+    max_model_len: Optional[int],
+    enforce_eager: bool,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -81,6 +82,7 @@ def run_vllm(
         trust_remote_code=trust_remote_code,
         dtype=dtype,
         max_model_len=max_model_len,
+        enforce_eager=enforce_eager,
     )
     # Add the requests to the engine.
@@ -204,7 +206,7 @@ def main(args: argparse.Namespace):
                                 args.quantization, args.tensor_parallel_size,
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
-                                args.max_model_len)
+                                args.max_model_len, args.enforce_eager)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -279,6 +281,9 @@ if __name__ == "__main__":
                        'The "auto" option will use FP16 precision '
                        'for FP32 and FP16 models, and BF16 precision '
                        'for BF16 models.')
+    parser.add_argument("--enforce-eager",
+                        action="store_true",
+                        help="enforce eager execution")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
...
@@ -12,3 +12,4 @@ fastapi
 uvicorn[standard]
 pydantic == 1.10.13  # Required for OpenAI server.
 aioprometheus[starlette]
+cupy-cuda12x  # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.  # FIXME: Fix this in setup.py.
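
Since the pinned wheel depends on the local CUDA runtime (and the FIXME notes this is not yet handled in setup.py), a tiny helper like the one below can tell users which package to install. This is illustrative only and not part of the commit; it assumes PyTorch's reported CUDA version matches the runtime cupy would use.

# Illustrative helper, not part of this commit: pick the cupy wheel that
# matches the CUDA runtime PyTorch was built against.
import torch

def suggested_cupy_package() -> str:
    cuda_version = torch.version.cuda or ""
    return "cupy-cuda11x" if cuda_version.startswith("11") else "cupy-cuda12x"

print(suggested_cupy_package())  # e.g. "cupy-cuda12x"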
@@ -49,6 +49,12 @@ class ModelConfig:
             output). If None, will be derived from the model.
         quantization: Quantization method that was used to quantize the model
             weights. If None, we assume the model weights are not quantized.
+        enforce_eager: Whether to enforce eager execution. If True, we will
+            disable CUDA graph and always execute the model in eager mode.
+            If False, we will use CUDA graph and eager execution in hybrid.
+        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
+            When a sequence has context length larger than this, we fall back
+            to eager mode.
     """

     def __init__(
@@ -65,6 +71,8 @@ class ModelConfig:
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
+        enforce_eager: bool = False,
+        max_context_len_to_capture: Optional[int] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -76,6 +84,8 @@ class ModelConfig:
         self.revision = revision
         self.tokenizer_revision = tokenizer_revision
         self.quantization = quantization
+        self.enforce_eager = enforce_eager
+        self.max_context_len_to_capture = max_context_len_to_capture

         if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
             # download model from ModelScope hub,
@@ -95,6 +105,7 @@ class ModelConfig:
         self._verify_load_format()
         self._verify_tokenizer_mode()
         self._verify_quantization()
+        self._verify_cuda_graph()

     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
@@ -169,6 +180,12 @@ class ModelConfig:
                 "optimized yet. The speed can be slower than "
                 "non-quantized models.")

+    def _verify_cuda_graph(self) -> None:
+        if self.max_context_len_to_capture is None:
+            self.max_context_len_to_capture = self.max_model_len
+        self.max_context_len_to_capture = min(self.max_context_len_to_capture,
+                                              self.max_model_len)
+
     def verify_with_parallel_config(
         self,
         parallel_config: "ParallelConfig",
...
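The effect of the new _verify_cuda_graph check can be summarized with a small standalone sketch; the numbers below are made up for illustration and are not taken from the commit.

# Standalone sketch of the clamping done by ModelConfig._verify_cuda_graph.
max_model_len = 4096
max_context_len_to_capture = None        # user left the new option unset

if max_context_len_to_capture is None:
    max_context_len_to_capture = max_model_len
max_context_len_to_capture = min(max_context_len_to_capture, max_model_len)

assert max_context_len_to_capture == 4096  # never exceeds max_model_len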
@@ -33,6 +33,8 @@ class EngineArgs:
     revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
+    enforce_eager: bool = False
+    max_context_len_to_capture: int = 8192

     def __post_init__(self):
         if self.tokenizer is None:
@@ -182,6 +184,17 @@ class EngineArgs:
             choices=['awq', 'gptq', 'squeezellm', None],
             default=None,
             help='Method used to quantize the weights')
+        parser.add_argument('--enforce-eager',
+                            action='store_true',
+                            help='Always use eager-mode PyTorch. If False, '
+                            'will use eager mode and CUDA graph in hybrid '
+                            'for maximal performance and flexibility.')
+        parser.add_argument('--max-context-len-to-capture',
+                            type=int,
+                            default=EngineArgs.max_context_len_to_capture,
+                            help='maximum context length covered by CUDA '
+                            'graphs. When a sequence has context length '
+                            'larger than this, we fall back to eager mode.')
         return parser

     @classmethod
@@ -200,7 +213,8 @@ class EngineArgs:
                                    self.download_dir, self.load_format,
                                    self.dtype, self.seed, self.revision,
                                    self.tokenizer_revision, self.max_model_len,
-                                   self.quantization)
+                                   self.quantization, self.enforce_eager,
+                                   self.max_context_len_to_capture)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space,
...
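For reference, a minimal sketch of how the new flags flow from the command line into the engine arguments. It assumes the existing add_cli_args/from_cli_args helpers behave as in vLLM at this point; the model name and values are placeholders.

# Minimal sketch, assuming EngineArgs.add_cli_args/from_cli_args as in vLLM.
import argparse

from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--model", "facebook/opt-125m", "--max-context-len-to-capture", "4096"])
engine_args = EngineArgs.from_cli_args(args)
print(engine_args.enforce_eager, engine_args.max_context_len_to_capture)
# False 4096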
@@ -17,7 +17,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
-from vllm.utils import Counter
+from vllm.utils import Counter, get_open_port

 if ray:
     from ray.air.util.torch_dist import init_torch_dist_process_group
@@ -84,6 +84,7 @@ class LLMEngine:
             f"load_format={model_config.load_format}, "
             f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
             f"quantization={model_config.quantization}, "
+            f"enforce_eager={model_config.enforce_eager}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
@@ -189,6 +190,7 @@ class LLMEngine:
             ))
         self._run_workers(
             "init_model",
+            cupy_port=get_open_port(),
             get_all_outputs=True,
         )
         self._run_workers(
@@ -232,6 +234,9 @@ class LLMEngine:
         # Initialize the cache.
         self._run_workers("init_cache_engine", cache_config=self.cache_config)
+        # Warm up the model. This includes capturing the model into CUDA graph
+        # if enforce_eager is False.
+        self._run_workers("warm_up_model")

     @classmethod
     def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
...
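The new "warm_up_model" step is where workers capture decoding graphs before serving. The snippet below is only a generic illustration of the underlying PyTorch capture/replay pattern, not the vLLM worker code introduced by this commit; the model and shapes are placeholders.

# Generic capture/replay pattern behind CUDA-graph warm-up (torch.cuda API).
import torch

model = torch.nn.Linear(512, 512).cuda()
static_input = torch.zeros(8, 512, device="cuda")

# Warm up on a side stream before capture, as recommended by PyTorch.
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    model(static_input)
torch.cuda.current_stream().wait_stream(side_stream)

# Capture one forward pass into a graph with fixed input/output buffers.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# At run time, copy fresh data into the static buffer and replay the graph.
static_input.copy_(torch.randn(8, 512, device="cuda"))
graph.replay()
print(static_output.shape)  # torch.Size([8, 512])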
-import socket
 from typing import Optional, Tuple, TYPE_CHECKING

 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import is_hip
+from vllm.utils import get_open_port, is_hip

 logger = init_logger(__name__)
@@ -43,12 +42,6 @@ if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup

-def get_open_port():
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind(("", 0))
-        return s.getsockname()[1]
-
 def initialize_cluster(
     parallel_config: ParallelConfig,
     engine_use_ray: bool = False,
...
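The helper removed here is not dropped from the project: it is now imported from vllm.utils (see the llm_engine hunk above, where it supplies the cupy_port passed to the workers' init_model). A standalone copy for reference, with the same logic as the removed lines:

# Standalone copy of the relocated helper; binding to port 0 lets the OS
# assign any free TCP port.
import socket

def get_open_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

print(get_open_port())  # e.g. 54217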
@@ -56,6 +56,12 @@ class LLM:
             when their `best_of` sampling parameters are larger than 1. If all
             requests will have `best_of=1`, you can safely set this to 0.
             Otherwise, too small values may cause out-of-memory (OOM) errors.
+        enforce_eager: Whether to enforce eager execution. If True, we will
+            disable CUDA graph and always execute the model in eager mode.
+            If False, we will use CUDA graph and eager execution in hybrid.
+        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
+            When a sequence has context length larger than this, we fall back
+            to eager mode.
     """

     def __init__(
@@ -72,6 +78,8 @@ class LLM:
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
+        enforce_eager: bool = False,
+        max_context_len_to_capture: int = 8192,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
@@ -89,6 +97,8 @@ class LLM:
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
+            enforce_eager=enforce_eager,
+            max_context_len_to_capture=max_context_len_to_capture,
             **kwargs,
         )
         self.llm_engine = LLMEngine.from_engine_args(engine_args)
...
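A minimal usage sketch of the new LLM arguments, assuming a small placeholder model that fits on one GPU; the prompt and sampling values are arbitrary.

# Minimal sketch of the new LLM arguments; the model name is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          enforce_eager=False,              # default: CUDA graph + eager hybrid
          max_context_len_to_capture=8192)  # longer contexts fall back to eager
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
print(outputs[0].outputs[0].text)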
@@ -21,12 +21,14 @@ class InputMetadata:
         max_context_len: Optional[int],
         context_lens: Optional[torch.Tensor],
         block_tables: Optional[torch.Tensor],
+        use_cuda_graph: bool,
     ) -> None:
         self.prompt_lens = prompt_lens
         self.max_context_len = max_context_len
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.block_tables = block_tables
+        self.use_cuda_graph = use_cuda_graph

         self.is_prompt = len(prompt_lens) > 0
         # Set during the execution of the first attention op.
@@ -39,4 +41,5 @@ class InputMetadata:
                 f"max_context_len={self.max_context_len}, "
                 f"slot_mapping={self.slot_mapping}, "
                 f"context_lens={self.context_lens}, "
-                f"block_tables={self.block_tables})")
+                f"block_tables={self.block_tables}, "
+                f"use_cuda_graph={self.use_cuda_graph})")
@@ -24,13 +24,10 @@ class PagedAttention(nn.Module):
     can either contain prompt tokens or generation tokens.

     The class does the following:
-    1. Wait for the cache operations (e.g., swap, copy) to finish. The cache
-       operations are issued by the cache engine before executing the forward
-       pass of the model, and they are executed asynchronously.
-    2. Reshape and store the input key and value tensors in the KV cache.
-    3. Perform (multi-head/multi-query/grouped-query) attention using either
+    1. Reshape and store the input key and value tensors in the KV cache.
+    2. Perform (multi-head/multi-query/grouped-query) attention using either
        xformers or the PagedAttention custom op.
-    4. Return the output tensor.
+    3. Return the output tensor.
     """

     def __init__(
@@ -67,7 +64,6 @@ class PagedAttention(nn.Module):
         key_cache: Optional[torch.Tensor],
         value_cache: Optional[torch.Tensor],
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         """PagedAttention forward pass.
@@ -80,7 +76,6 @@ class PagedAttention(nn.Module):
             value_cache: shape = [num_blocks, num_kv_heads, head_size,
                 block_size]
             input_metadata: metadata for the inputs.
-            cache_event: event to wait for the cache operations to finish.

         Returns:
             shape = [batch_size, seq_len, num_heads * head_size]
         """
@@ -89,10 +84,6 @@ class PagedAttention(nn.Module):
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_kv_heads, self.head_size)
         value = value.view(-1, self.num_kv_heads, self.head_size)
-        slot_mapping = input_metadata.slot_mapping.flatten()
-        if cache_event is not None:
-            cache_event.wait()

         # Reshape the keys and values and store them in the cache.
         # If key_cache and value_cache are not provided, the new key and value
@@ -104,7 +95,7 @@ class PagedAttention(nn.Module):
                 value,
                 key_cache,
                 value_cache,
-                slot_mapping,
+                input_metadata.slot_mapping.flatten(),
             )

         if input_metadata.is_prompt:
@@ -165,15 +156,20 @@ class PagedAttention(nn.Module):
             output = out.view_as(query)
         else:
             # Decoding run.
-            output = _paged_attention(
-                query,
-                key_cache,
-                value_cache,
-                input_metadata,
-                self.num_kv_heads,
-                self.scale,
-                self.alibi_slopes,
-            )
+            if key_cache is not None and value_cache is not None:
+                output = _paged_attention(
+                    query,
+                    key_cache,
+                    value_cache,
+                    input_metadata,
+                    self.num_kv_heads,
+                    self.scale,
+                    self.alibi_slopes,
+                )
+            else:
+                # This happens during the initial memory profiling run for
+                # CUDA graphs.
+                output = torch.zeros_like(query)

         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
...
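The new key_cache/value_cache guard matters because the memory-profiling forward pass, which sizes the KV cache before any graphs are captured, runs with unallocated caches. A self-contained toy version of just that branch (not vLLM code; shapes are made up):

# Toy version of the decode-time branch above.
from typing import Optional

import torch

def decode_or_profile(query: torch.Tensor,
                      key_cache: Optional[torch.Tensor],
                      value_cache: Optional[torch.Tensor]) -> torch.Tensor:
    if key_cache is not None and value_cache is not None:
        # The real path would call the paged-attention kernel here.
        raise NotImplementedError("placeholder for _paged_attention")
    # Memory-profiling run: KV cache not allocated yet, return dummy zeros.
    return torch.zeros_like(query)

q = torch.randn(2, 8, 64)  # [num_tokens, num_heads, head_size], made-up shape
print(decode_or_profile(q, None, None).shape)  # torch.Size([2, 8, 64])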
@@ -158,14 +158,12 @@ class AquilaAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.o_proj(attn_output)
         return output
@@ -209,7 +207,6 @@ class AquilaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
@@ -219,7 +216,6 @@ class AquilaDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         hidden_states = residual + hidden_states
@@ -258,18 +254,15 @@ class AquilaModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         for i in range(len(self.layers)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
             )
         hidden_states = self.norm(hidden_states)
@@ -296,10 +289,9 @@ class AquilaForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata, cache_events)
+                                   input_metadata)
         return hidden_states

     def sample(
...
@@ -172,15 +172,13 @@ class BaiChuanAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         if self.postion_embedding != "ALIBI":
             q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.o_proj(attn_output)
         return output
@@ -221,7 +219,6 @@ class BaiChuanDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -236,7 +233,6 @@ class BaiChuanDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # Fully Connected
@@ -273,19 +269,16 @@ class BaiChuanModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -311,10 +304,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata, cache_events)
+                                   input_metadata)
         return hidden_states

     def sample(
...
@@ -118,14 +118,12 @@ class BloomAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         del position_ids  # Unused.
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.dense(attn_output)
         return output
@@ -184,7 +182,6 @@ class BloomBlock(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
@@ -201,7 +198,6 @@ class BloomBlock(nn.Module):
             hidden_states=layernorm_output,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         attention_output = attention_output + residual
         layernorm_output = self.post_attention_layernorm(attention_output)
@@ -250,19 +246,16 @@ class BloomModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.word_embeddings(input_ids)
         hidden_states = self.word_embeddings_layernorm(hidden_states)
         for i in range(len(self.h)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.h[i]
             hidden_states = layer(
                 position_ids,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
             )
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -288,10 +281,9 @@ class BloomForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata, cache_events)
+                                         input_metadata)
         return hidden_states

     def sample(
...
@@ -100,7 +100,6 @@ class GLMAttention(nn.Module):
         position_ids: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -113,7 +112,6 @@ class GLMAttention(nn.Module):
             key_cache,
             value_cache,
             input_metadata,
-            cache_event,
         )
         attn_output, _ = self.dense(context_layer)
         return attn_output
@@ -203,7 +201,6 @@ class GLMBlock(nn.Module):
         position_ids: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         # hidden_states: [num_tokens, h]
         # Layer norm at the beginning of the transformer layer.
@@ -214,7 +211,6 @@ class GLMBlock(nn.Module):
             position_ids=position_ids,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # Residual connection.
@@ -269,17 +265,14 @@ class GLMTransformer(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         for i in range(self.num_layers):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states = layer(
                 hidden_states=hidden_states,
                 position_ids=position_ids,
                 kv_cache=kv_caches[i],
                 input_metadata=input_metadata,
-                cache_event=cache_event,
             )
         # Final layer norm.
         if self.post_layer_norm:
@@ -314,8 +307,7 @@ class ChatGLMModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
-    ):
+    ) -> torch.Tensor:
         inputs_embeds = self.embedding(input_ids)

         # Run encoder.
@@ -324,9 +316,7 @@ class ChatGLMModel(nn.Module):
             position_ids=position_ids,
             kv_caches=kv_caches,
             input_metadata=input_metadata,
-            cache_events=cache_events,
         )
         return hidden_states
@@ -350,10 +340,9 @@ class ChatGLMForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata, cache_events)
+                                         input_metadata)
         return hidden_states

     def sample(
...
@@ -178,7 +178,6 @@ class FalconAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, bias = self.query_key_value(hidden_states)
         if bias is not None:
@@ -187,8 +186,7 @@ class FalconAttention(nn.Module):
         if self.use_rotary:
             q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         attn_output, bias = self.dense(attn_output)
         return attn_output, bias
@@ -266,8 +264,7 @@ class FalconDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
-    ):
+    ) -> torch.Tensor:
         residual = hidden_states

         if self.config.new_decoder_architecture:
@@ -282,7 +279,6 @@ class FalconDecoderLayer(nn.Module):
             hidden_states=attention_layernorm_out,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         if self.reduce_row_parallel_results and attention_bias is not None:
             attention_output += attention_bias
@@ -311,7 +307,6 @@ class FalconDecoderLayer(nn.Module):
             mlp_output += mlp_bias

         output = mlp_output + residual
         return output
@@ -349,18 +344,15 @@ class FalconModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.word_embeddings(input_ids)
         for i in range(len(self.h)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.h[i]
             hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
             )
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -389,14 +381,12 @@ class FalconForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,
             positions,
             kv_caches,
             input_metadata,
-            cache_events,
         )
         return hidden_states
...
@@ -82,13 +82,12 @@ class GPT2Attention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
-                                input_metadata, cache_event)
+                                input_metadata)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output
@@ -148,7 +147,6 @@ class GPT2Block(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
@@ -156,7 +154,6 @@ class GPT2Block(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # residual connection
         hidden_states = attn_output + residual
@@ -196,17 +193,14 @@ class GPT2Model(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         hidden_states = inputs_embeds + position_embeds
         for i in range(len(self.h)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.h[i]
-            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
-                                  cache_event)
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -232,10 +226,9 @@ class GPT2LMHeadModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata, cache_events)
+                                         input_metadata)
         return hidden_states

     def sample(
...
@@ -95,7 +95,6 @@ class GPTBigCodeAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.split(
@@ -107,7 +106,7 @@ class GPTBigCodeAttention(nn.Module):
         )
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
-                                input_metadata, cache_event)
+                                input_metadata)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output
@@ -167,7 +166,6 @@ class GPTBigCodeBlock(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
@@ -175,7 +173,6 @@ class GPTBigCodeBlock(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # residual connection
         hidden_states = attn_output + residual
@@ -215,17 +212,14 @@ class GPTBigCodeModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         hidden_states = inputs_embeds + position_embeds
         for i in range(len(self.h)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.h[i]
-            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
-                                  cache_event)
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -251,10 +245,9 @@ class GPTBigCodeForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata, cache_events)
+                                         input_metadata)
         return hidden_states

     def sample(
...
@@ -94,14 +94,12 @@ class GPTJAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         attn_output, _ = self.out_proj(attn_output)
         return attn_output
@@ -156,7 +154,6 @@ class GPTJBlock(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
@@ -165,7 +162,6 @@ class GPTJBlock(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         mlp_output = self.mlp(hidden_states)
         hidden_states = attn_output + mlp_output + residual
@@ -196,18 +192,15 @@ class GPTJModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
         for i in range(len(self.h)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.h[i]
             hidden_states = layer(
                 position_ids,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
             )
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -238,10 +231,9 @@ class GPTJForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata, cache_events)
+                                         input_metadata)
         return hidden_states

     def sample(
...
@@ -92,14 +92,12 @@ class GPTNeoXAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.query_key_value(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.dense(attn_output)
         return output
@@ -155,7 +153,6 @@ class GPTNeoXLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         attn_input = self.input_layernorm(hidden_states)
         attn_output = self.attention(
@@ -163,7 +160,6 @@ class GPTNeoXLayer(nn.Module):
             hidden_states=attn_input,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         if self.use_parallel_residual:
@@ -210,18 +206,15 @@ class GPTNeoXModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_in(input_ids)
         for i in range(len(self.layers)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states = layer(
                 position_ids,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
             )
         hidden_states = self.final_layer_norm(hidden_states)
         return hidden_states
@@ -250,10 +243,9 @@ class GPTNeoXForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
-                                      input_metadata, cache_events)
+                                      input_metadata)
         return hidden_states

     def sample(
...
@@ -110,14 +110,12 @@ class InternLMAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.o_proj(attn_output)
         return output
@@ -160,7 +158,6 @@ class InternLMDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -175,7 +172,6 @@ class InternLMDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # Fully Connected
@@ -214,19 +210,16 @@ class InternLMModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -253,10 +246,9 @@ class InternLMForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata)
+                                   input_metadata)
         return hidden_states

     def sample(
...