Unverified Commit 37ca5581 authored by Woosuk Kwon, committed by GitHub

Optimize model execution with CUDA graph (#1926)


Co-authored-by: Chen Shen <scv119@gmail.com>
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent eed74a55
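
As background for the diff below: a minimal, illustrative sketch of the generic PyTorch CUDA graph capture/replay pattern that the new CUDAGraphRunner applies to vLLM's decode step. The model and buffer names here are placeholders and are not part of this commit.

import torch

# Run once to warm up, capture the forward pass into static buffers,
# then refill those buffers and replay the graph on every step.
model = torch.nn.Linear(16, 16).cuda()
static_in = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream so one-time kernel launches are not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    model(static_in)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_in)

# Replay: copy fresh inputs into the captured buffer, then replay the graph.
static_in.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
result = static_out.clone()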
@@ -147,14 +147,12 @@ class LlamaAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.o_proj(attn_output)
         return output
@@ -198,7 +196,6 @@ class LlamaDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -213,7 +210,6 @@ class LlamaDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # Fully Connected
@@ -250,19 +246,16 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -289,10 +282,9 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata, cache_events)
+                                   input_metadata)
         return hidden_states

     def sample(
...
@@ -145,14 +145,12 @@ class MistralAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.o_proj(attn_output)
         return output
@@ -193,7 +191,6 @@ class MistralDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -208,7 +205,6 @@ class MistralDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # Fully Connected
@@ -246,19 +242,16 @@ class MistralModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -285,10 +278,9 @@ class MistralForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata, cache_events)
+                                   input_metadata)
         return hidden_states

     def sample(
...
@@ -253,14 +253,12 @@ class MixtralAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.o_proj(attn_output)
         return output
@@ -297,7 +295,6 @@ class MixtralDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -312,7 +309,6 @@ class MixtralDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # Fully Connected
@@ -349,16 +345,14 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> SamplerOutput:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states, residual = layer(positions, hidden_states,
                                             kv_caches[i], input_metadata,
-                                            cache_event, residual)
+                                            residual)
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -383,10 +377,9 @@ class MixtralForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata, cache_events)
+                                   input_metadata)
         return hidden_states

     def sample(
...
@@ -117,7 +117,6 @@ class MPTAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         del position_ids  # unused.
         qkv, _ = self.Wqkv(hidden_states)
@@ -128,8 +127,7 @@ class MPTAttention(nn.Module):
             q = self.q_ln(q)
             k = self.k_ln(k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.out_proj(attn_output)
         return output
@@ -187,7 +185,6 @@ class MPTBlock(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         x = self.norm_1(hidden_states)
         x = self.attn(
@@ -195,7 +192,6 @@ class MPTBlock(nn.Module):
             hidden_states=x,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         hidden_states = hidden_states + x
         x = self.norm_2(hidden_states)
@@ -235,18 +231,15 @@ class MPTModel(nn.Module):
         position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
         for i in range(len(self.blocks)):
-            cache_event = None if cache_events is None else cache_events[i]
             block = self.blocks[i]
             hidden_states = block(
                 position_ids,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
             )
         hidden_states = self.norm_f(hidden_states)
         return hidden_states
@@ -274,10 +267,9 @@ class MPTForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata, cache_events)
+                                         input_metadata)
         return hidden_states

     def sample(
...
@@ -98,13 +98,12 @@ class OPTAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         key_cache, value_cache = kv_cache
         attn_output = self.attn(q, k, v, key_cache, value_cache,
-                                input_metadata, cache_event)
+                                input_metadata)
         output, _ = self.out_proj(attn_output)
         return output
@@ -154,7 +153,6 @@ class OPTDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
@@ -163,8 +161,7 @@ class OPTDecoderLayer(nn.Module):
             hidden_states = self.self_attn_layer_norm(hidden_states)
         hidden_states = self.self_attn(hidden_states=hidden_states,
                                        kv_cache=kv_cache,
-                                       input_metadata=input_metadata,
-                                       cache_event=cache_event)
+                                       input_metadata=input_metadata)
         hidden_states = residual + hidden_states
         # 350m applies layer norm AFTER attention
         if not self.do_layer_norm_before:
@@ -245,7 +242,6 @@ class OPTDecoder(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         inputs_embeds = self.embed_tokens(input_ids)
         pos_embeds = self.embed_positions(positions)
@@ -254,10 +250,8 @@ class OPTDecoder(nn.Module):
         hidden_states = inputs_embeds + pos_embeds

         for i in range(len(self.layers)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
-            hidden_states = layer(hidden_states, kv_caches[i], input_metadata,
-                                  cache_event)
+            hidden_states = layer(hidden_states, kv_caches[i], input_metadata)

         if self.final_layer_norm is not None:
             hidden_states = self.final_layer_norm(hidden_states)
@@ -282,10 +276,8 @@ class OPTModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
-        return self.decoder(input_ids, positions, kv_caches, input_metadata,
-                            cache_events)
+        return self.decoder(input_ids, positions, kv_caches, input_metadata)


 class OPTForCausalLM(nn.Module):
@@ -308,10 +300,9 @@ class OPTForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata, cache_events)
+                                   input_metadata)
         return hidden_states

     def sample(
...
@@ -135,14 +135,12 @@ class PhiAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.Wqkv(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(position_ids, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.out_proj(attn_output)
         return output
@@ -195,7 +193,6 @@ class PhiLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln(hidden_states)
@@ -204,7 +201,6 @@ class PhiLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         feed_forward_hidden_states = self.mlp(hidden_states)
         hidden_states = attn_outputs + feed_forward_hidden_states + residual
@@ -231,18 +227,15 @@ class PhiModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embd(input_ids)
         for i in range(self.config.num_hidden_layers):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.h[i]
             hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
             )
         return hidden_states
@@ -277,10 +270,9 @@ class PhiForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata, cache_events)
+                                         input_metadata)
         hidden_states = self.lm_head.ln(hidden_states)
         return hidden_states
...
@@ -112,14 +112,12 @@ class QWenAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.c_proj(attn_output)
         return output
@@ -156,7 +154,6 @@ class QWenBlock(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -170,7 +167,6 @@ class QWenBlock(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # Fully Connected
@@ -206,19 +202,16 @@ class QWenModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
         residual = None
         for i in range(len(self.h)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.h[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
                 residual,
             )
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -245,10 +238,9 @@ class QWenLMHeadModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         input_metadata, cache_events)
+                                         input_metadata)
         return hidden_states

     def sample(
...
@@ -146,14 +146,12 @@ class YiAttention(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
-                                cache_event)
+        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
         output, _ = self.o_proj(attn_output)
         return output
@@ -195,7 +193,6 @@ class YiDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
-        cache_event: Optional[torch.cuda.Event],
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -209,7 +206,6 @@ class YiDecoderLayer(nn.Module):
             hidden_states=hidden_states,
             kv_cache=kv_cache,
             input_metadata=input_metadata,
-            cache_event=cache_event,
         )
         # Fully Connected
@@ -245,19 +241,16 @@ class YiModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
-            cache_event = None if cache_events is None else cache_events[i]
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
                 kv_caches[i],
                 input_metadata,
-                cache_event,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -284,10 +277,9 @@ class YiForCausalLM(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-        cache_events: Optional[List[torch.cuda.Event]],
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   input_metadata, cache_events)
+                                   input_metadata)
        return hidden_states

     def sample(
...
 import torch

+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
+    is_custom_nccl_enabled_for_all_reduce,
 )
@@ -15,6 +17,10 @@ def tensor_model_parallel_all_reduce(input_):
     if get_tensor_model_parallel_world_size() == 1:
         return input_
     # All-reduce.
-    torch.distributed.all_reduce(input_,
-                                 group=get_tensor_model_parallel_group())
+    if is_custom_nccl_enabled_for_all_reduce():
+        # TODO: support multiple parallel groups.
+        cupy_utils.all_reduce(input_)
+    else:
+        torch.distributed.all_reduce(input_,
+                                     group=get_tensor_model_parallel_group())
     return input_
...
"""CuPy utilities for all-reduce.
We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import contextlib
import torch
from torch.distributed import ReduceOp
try:
import cupy
from cupyx.distributed import NCCLBackend
from cupy.cuda import nccl
except ImportError as e:
cupy = e
nccl = None
class NCCLBackend:
...
_OP_MAPPING = {
ReduceOp.SUM: "sum",
ReduceOp.PRODUCT: "prod",
ReduceOp.MIN: "min",
ReduceOp.MAX: "max",
}
class NCCLBackendWithBFloat16(NCCLBackend):
# This is enough to add bfloat16 support for most operations,
# but broadcast will fail (will require changes in compiled
# cupy code).
def _get_nccl_dtype_and_count(self, array, count=None):
nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count)
torch_dtype = getattr(array, "_torch_dtype", None)
if torch_dtype is torch.bfloat16:
nccl_dtype = nccl.NCCL_BFLOAT16
return nccl_dtype, count
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def is_initialized() -> bool:
"""Returns whether the NCCL backend is initialized."""
return _NCCL_BACKEND is not None
@contextlib.contextmanager
def set_cupy_stream(stream: torch.cuda.Stream) -> None:
"""Set the cuda stream for communication"""
cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream,
stream.device_index)
with cupy_stream:
yield
def init_process_group(world_size: int, rank: int, host: str,
port: int) -> None:
"""Initializes the CuPy NCCL backend.
# TODO: handle NCCL timeouts.
"""
assert not is_initialized()
if isinstance(cupy, Exception):
raise ImportError(
"NCCLBackend is not available. Please install cupy.") from cupy
# TODO(woosuk): Create TP and PP process groups for CuPy.
global _NCCL_BACKEND
global _WORLD_SIZE
assert world_size > 0, f"{world_size=} should be a positive integer"
assert 0 <= rank < world_size, (
f"{rank=} should be a integer between [0, {world_size})")
cupy.cuda.runtime.setDevice(torch.cuda.current_device())
_NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
_WORLD_SIZE = world_size
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
# Hack to support bfloat16
torch_dtype = input_.dtype
if torch_dtype is torch.bfloat16:
# We need to view as float16, otherwise
# cupy will fail. This will not change
# the underlying data.
input_ = input_.view(torch.float16)
cupy_input = cupy.asarray(input_)
cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access
_NCCL_BACKEND.all_reduce(in_array=cupy_input,
out_array=cupy_input,
op=_OP_MAPPING[op])
def destroy_process_group() -> None:
"""Destroys the NCCL backend."""
global _NCCL_BACKEND
global _WORLD_SIZE
_NCCL_BACKEND = None
_WORLD_SIZE = 0
def get_world_size() -> int:
"""Returns the world size."""
return _WORLD_SIZE
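
For orientation, a hedged sketch of how the new cupy_utils module is driven by the worker and parallel-state code later in this diff; the wrapper function itself is illustrative only and not part of the commit.

import torch
from vllm.model_executor.parallel_utils import cupy_utils

def example_all_reduce(world_size: int, rank: int, port: int) -> torch.Tensor:
    # One-time setup per worker process (single-node, host fixed to localhost
    # as in worker.py below).
    if world_size > 1 and not cupy_utils.is_initialized():
        cupy_utils.init_process_group(world_size, rank, "localhost", port)
    x = torch.ones(4, device="cuda")
    # Route the collective through the current CUDA stream so it can be
    # recorded into a CUDA graph.
    with cupy_utils.set_cupy_stream(torch.cuda.current_stream()):
        cupy_utils.all_reduce(x)  # in-place sum across ranks
    return x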
@@ -3,9 +3,12 @@
 # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 """Tensor and pipeline parallel groups."""
+import contextlib
+
 import torch

+from vllm.model_executor.parallel_utils import cupy_utils
+
 # Tensor model parallel group that the current rank belongs to.
 _TENSOR_MODEL_PARALLEL_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
@@ -177,3 +180,37 @@ def destroy_model_parallel():
     _PIPELINE_MODEL_PARALLEL_GROUP = None
     global _PIPELINE_GLOBAL_RANKS
     _PIPELINE_GLOBAL_RANKS = None
+    # Destroy the cupy states if any.
+    cupy_utils.destroy_process_group()
+
+
+# Whether to use cupy for nccl all reduce.
+# We use cupy for all reduce when using CUDA graph, because torch.distributed
+# is not well supported by CUDA graph.
+_ENABLE_CUPY_FOR_ALL_REDUCE = False
+
+
+@contextlib.contextmanager
+def with_custom_nccl_for_all_reduce():
+    """use custom nccl instead of torch.distributed for all reduce"""
+    tp_size = get_tensor_model_parallel_world_size()
+    if tp_size == 1:
+        # No-op.
+        # NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
+        yield
+    else:
+        global _ENABLE_CUPY_FOR_ALL_REDUCE
+        old = _ENABLE_CUPY_FOR_ALL_REDUCE
+        _ENABLE_CUPY_FOR_ALL_REDUCE = True
+
+        stream = torch.cuda.current_stream()
+        with cupy_utils.set_cupy_stream(stream):
+            yield
+        _ENABLE_CUPY_FOR_ALL_REDUCE = old
+
+
+def is_custom_nccl_enabled_for_all_reduce():
+    """check if custom nccl is enabled for all reduce"""
+    global _ENABLE_CUPY_FOR_ALL_REDUCE
+    return _ENABLE_CUPY_FOR_ALL_REDUCE
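
A small illustrative use of the new context manager: any tensor-parallel all-reduce issued inside it is routed through cupy_utils rather than torch.distributed, which is what makes the collective safe to record into a CUDA graph (see CUDAGraphRunner.capture further down). The wrapper function below is hypothetical.

import torch
from vllm.model_executor.parallel_utils.parallel_state import (
    with_custom_nccl_for_all_reduce)

def run_for_capture(model: torch.nn.Module, *inputs):
    # While this context is active, is_custom_nccl_enabled_for_all_reduce()
    # returns True and tensor_model_parallel_all_reduce uses the CuPy backend.
    with with_custom_nccl_for_all_reduce():
        return model(*inputs)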
 import enum
+import socket
 import uuid
 from platform import uname
@@ -52,3 +53,9 @@ def random_uuid() -> str:
 def in_wsl() -> bool:
     # Reference: https://github.com/microsoft/WSL/issues/4071
     return "microsoft" in " ".join(uname()).lower()
+
+
+def get_open_port():
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        s.bind(("", 0))
+        return s.getsockname()[1]
-from typing import Dict, List, Optional, Tuple
+import time
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
 import torch
+import torch.nn as nn

 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils.parallel_state import (
+    with_custom_nccl_for_all_reduce)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata

 logger = init_logger(__name__)

+KVCache = Tuple[torch.Tensor, torch.Tensor]
 _PAD_SLOT_ID = -1
+# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
+# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


 class ModelRunner:
@@ -32,12 +41,31 @@ class ModelRunner:
         self.model = None
         self.block_size = None  # Set after initial profiling.

+        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
+        self.graph_memory_pool = None  # Set during graph capture.
+
+        self.max_context_len_to_capture = (
+            self.model_config.max_context_len_to_capture
+            if self.model_config is not None else 0)
+        # When using CUDA graph, the input block tables must be padded to
+        # max_context_len_to_capture. However, creating the block table in
+        # Python can be expensive. To optimize this, we cache the block table
+        # in numpy and only copy the actual input content at every iteration.
+        # The shape of the cached block table will be
+        # (max batch size to capture, max context len to capture / block size).
+        self.graph_block_tables = None  # Set after initial profiling.
+
     def load_model(self) -> None:
         self.model = get_model(self.model_config)

     def set_block_size(self, block_size: int) -> None:
         self.block_size = block_size

+        max_num_blocks = (self.max_context_len_to_capture + block_size -
+                          1) // block_size
+        self.graph_block_tables = np.zeros(
+            (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32)
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -111,6 +139,7 @@ class ModelRunner:
             max_context_len=None,
             context_lens=None,
             block_tables=None,
+            use_cuda_graph=False,
         )
         return input_tokens, input_positions, input_metadata
@@ -154,27 +183,62 @@ class ModelRunner:
                 block_table = block_table[-sliding_window_blocks:]
             block_tables.append(block_table)

+        batch_size = len(input_tokens)
+        max_context_len = max(context_lens)
+        use_captured_graph = (
+            not self.model_config.enforce_eager
+            and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
+            and max_context_len <= self.max_context_len_to_capture)
+        if use_captured_graph:
+            # Pad the input tokens, positions, and slot mapping to match the
+            # batch size of the captured graph.
+            graph_batch_size = _get_graph_batch_size(batch_size)
+            assert graph_batch_size >= batch_size
+            for _ in range(graph_batch_size - batch_size):
+                input_tokens.append([])
+                input_positions.append([])
+                slot_mapping.append([])
+                context_lens.append(1)
+                block_tables.append([])
+            batch_size = graph_batch_size
+
+        # When using CUDA graph, we don't need to make the tensors on the GPU
+        # because they will be eventually copied to the designated GPU buffer.
+        device = "cpu" if use_captured_graph else "cuda"
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
-                                             dtype=torch.long)
+                                             dtype=torch.long,
+                                             device=device)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
-                                                dtype=torch.long)
+                                                dtype=torch.long,
+                                                device=device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
-                                             dtype=torch.long)
-        max_context_len = max(context_lens)
+                                             dtype=torch.long,
+                                             device=device)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device="cuda")
-        max_block_table_len = max([len(t) for t in block_tables])
-        block_tables = _make_tensor_with_pad(block_tables,
-                                             max_len=max_block_table_len,
-                                             pad=0,
-                                             dtype=torch.int)
+                                    device=device)
+
+        if use_captured_graph:
+            # The shape of graph_block_tables is
+            # [max batch size, max context len // block size].
+            input_block_tables = self.graph_block_tables[:batch_size]
+            for i, block_table in enumerate(block_tables):
+                if block_table:
+                    input_block_tables[i, :len(block_table)] = block_table
+            block_tables = torch.from_numpy(input_block_tables).to(device)
+        else:
+            block_tables = _make_tensor_with_pad(
+                block_tables,
+                max_len=max_context_len,
+                pad=0,
+                dtype=torch.int,
+            )

         input_metadata = InputMetadata(
             prompt_lens=[],
@@ -182,6 +246,7 @@ class ModelRunner:
             max_context_len=max_context_len,
             context_lens=context_lens,
             block_tables=block_tables,
+            use_cuda_graph=use_captured_graph,
         )
         return input_tokens, input_positions, input_metadata
@@ -260,12 +325,11 @@ class ModelRunner:
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-        cache_events: Optional[List[torch.cuda.Event]] = None,
     ) -> SamplerOutput:
         # NOTE: We assume that all sequences in the group are all prompts or
         # all decodes.
-        # Prepare input tensors.
         is_prompt = seq_group_metadata_list[0].is_prompt
+        # Prepare input tensors.
         if is_prompt:
             inputs = self._prepare_prompt(seq_group_metadata_list)
             input_tokens, input_positions, input_metadata = inputs
@@ -276,12 +340,16 @@ class ModelRunner:
                 input_metadata.prompt_lens)

         # Execute the model.
-        hidden_states = self.model(
+        if input_metadata.use_cuda_graph:
+            graph_batch_size = input_tokens.shape[0]
+            model_executable = self.graph_runners[graph_batch_size]
+        else:
+            model_executable = self.model
+        hidden_states = model_executable(
             input_ids=input_tokens,
             positions=input_positions,
             kv_caches=kv_caches,
             input_metadata=input_metadata,
-            cache_events=cache_events,
         )

         # Sample the next token.
@@ -319,8 +387,139 @@ class ModelRunner:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [(None, None)] * num_layers
         self.execute_model(seqs, kv_caches)
+        torch.cuda.synchronize()
         return

+    @torch.inference_mode()
+    def capture_model(self, kv_caches: List[KVCache]) -> None:
+        assert not self.model_config.enforce_eager
+        logger.info("Capturing the model for CUDA graphs. This may lead to "
+                    "unexpected consequences if the model is not static. To "
+                    "run the model in eager mode, set 'enforce_eager=True' or "
+                    "use '--enforce-eager' in the CLI.")
+        start_time = time.perf_counter()
+
+        # Prepare dummy inputs. These will be reused for all batch sizes.
+        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+        input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda()
+        input_positions = torch.zeros(max_batch_size, 1,
+                                      dtype=torch.long).cuda()
+        slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda()
+        slot_mapping.fill_(_PAD_SLOT_ID)
+        context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
+        block_tables = torch.from_numpy(self.graph_block_tables).cuda()
+
+        # NOTE: Capturing the largest batch size first may help reduce the
+        # memory usage of CUDA graph.
+        for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
+            # Create dummy input_metadata.
+            input_metadata = InputMetadata(
+                prompt_lens=[],
+                slot_mapping=slot_mapping[:batch_size],
+                max_context_len=self.max_context_len_to_capture,
+                context_lens=context_lens[:batch_size],
+                block_tables=block_tables[:batch_size],
+                use_cuda_graph=True,
+            )
+
+            graph_runner = CUDAGraphRunner(self.model)
+            graph_runner.capture(
+                input_tokens[:batch_size],
+                input_positions[:batch_size],
+                kv_caches,
+                input_metadata,
+                memory_pool=self.graph_memory_pool,
+            )
+            self.graph_memory_pool = graph_runner.graph.pool()
+            self.graph_runners[batch_size] = graph_runner
+
+        end_time = time.perf_counter()
+        elapsed_time = end_time - start_time
+        # This usually takes < 10 seconds.
+        logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")
+
+
+class CUDAGraphRunner:
+
+    def __init__(self, model: nn.Module):
+        self.model = model
+        self.graph = None
+        self.input_buffers: Dict[str, torch.Tensor] = {}
+        self.output_buffers: Dict[str, torch.Tensor] = {}
+
+    def capture(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        memory_pool,
+    ) -> None:
+        assert self.graph is None
+        # Run the model once without capturing the graph.
+        # This is to make sure that the captured graph does not include the
+        # kernel launches for initial benchmarking (e.g., Triton autotune).
+        with with_custom_nccl_for_all_reduce():
+            self.model(
+                input_ids,
+                positions,
+                kv_caches,
+                input_metadata,
+            )
+        torch.cuda.synchronize()
+
+        # Capture the graph.
+        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
+        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
+        self.graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
+            with with_custom_nccl_for_all_reduce():
+                hidden_states = self.model(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    input_metadata,
+                )
+        torch.cuda.synchronize()
+
+        # Save the input and output buffers.
+        self.input_buffers = {
+            "input_ids": input_ids,
+            "positions": positions,
+            "kv_caches": kv_caches,
+            "slot_mapping": input_metadata.slot_mapping,
+            "context_lens": input_metadata.context_lens,
+            "block_tables": input_metadata.block_tables,
+        }
+        self.output_buffers = {"hidden_states": hidden_states}
+        return
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        # KV caches are fixed tensors, so we don't need to copy them.
+        del kv_caches
+
+        # Copy the input tensors to the input buffers.
+        self.input_buffers["input_ids"].copy_(input_ids)
+        self.input_buffers["positions"].copy_(positions)
+        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
+        self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
+        self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
+
+        # Run the graph.
+        self.graph.replay()
+
+        # Return the output tensor.
+        return self.output_buffers["hidden_states"]
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+

 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
@@ -332,6 +531,16 @@ def _make_tensor_with_pad(
     max_len: int,
     pad: int,
     dtype: torch.dtype,
+    device: Union[str, torch.device] = "cuda",
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device="cuda")
+    return torch.tensor(padded_x, dtype=dtype, device=device)
+
+
+def _get_graph_batch_size(batch_size: int) -> int:
+    if batch_size <= 2:
+        return batch_size
+    elif batch_size <= 4:
+        return 4
+    else:
+        return (batch_size + 7) // 8 * 8
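
A quick arithmetic illustration of how _get_graph_batch_size rounds decode batch sizes up to the sizes listed in _BATCH_SIZES_TO_CAPTURE (this snippet assumes it runs inside the module above):

# Batch sizes round up to the nearest captured size:
# 1 -> 1, 2 -> 2, 3 -> 4, 5 -> 8, 13 -> 16, 250 -> 256.
for bs in (1, 2, 3, 5, 13, 250):
    print(bs, _get_graph_batch_size(bs))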
@@ -8,6 +8,7 @@ import torch.distributed
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -46,7 +47,7 @@ class Worker:
         self.cache_events = None
         self.gpu_cache = None

-    def init_model(self):
+    def init_model(self, cupy_port: Optional[int] = None):
         # This env var set by Ray causes exceptions with graph building.
         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

         # Env vars will be set by Ray.
@@ -62,7 +63,7 @@ class Worker:

         # Initialize the distributed environment.
         _init_distributed_environment(self.parallel_config, self.rank,
-                                      self.distributed_init_method)
+                                      cupy_port, self.distributed_init_method)

         # Initialize the model.
         set_random_seed(self.model_config.seed)
@@ -100,10 +101,6 @@ class Worker:
         num_gpu_blocks = max(num_gpu_blocks, 0)
         num_cpu_blocks = max(num_cpu_blocks, 0)
         torch.cuda.empty_cache()
-
-        # Reset the seed to ensure that the random state is not affected by
-        # the model initialization and profiling.
-        set_random_seed(self.model_config.seed)
         return num_gpu_blocks, num_cpu_blocks

     def init_cache_engine(self, cache_config: CacheConfig) -> None:
@@ -114,6 +111,13 @@ class Worker:
         self.gpu_cache = self.cache_engine.gpu_cache
         self.model_runner.set_block_size(self.cache_engine.block_size)

+    def warm_up_model(self) -> None:
+        if not self.model_config.enforce_eager:
+            self.model_runner.capture_model(self.gpu_cache)
+        # Reset the seed to ensure that the random state is not affected by
+        # the model initialization and profiling.
+        set_random_seed(self.model_config.seed)
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -136,21 +140,24 @@ class Worker:
         cache_events = self.cache_events if issued_cache_op else None

-        # If there is no input, we don't need to execute the model.
-        if not seq_group_metadata_list:
-            if cache_events is not None:
-                for event in cache_events:
-                    event.wait()
+        # Wait for cache operations to finish.
+        # TODO(woosuk): Profile swapping overhead and optimize if needed.
+        if cache_events is not None:
+            for event in cache_events:
+                event.wait()
+
+        # If there is no input, we don't need to execute the model.
+        if not seq_group_metadata_list:
             return {}

-        output = self.model_runner.execute_model(seq_group_metadata_list,
-                                                 self.gpu_cache, cache_events)
+        output = self.model_runner.execute_model(seq_group_metadata_list,
+                                                 self.gpu_cache)
         return output


 def _init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
+    cupy_port: Optional[int],
     distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
@@ -173,8 +180,29 @@ def _init_distributed_environment(
         init_method=distributed_init_method,
     )

+    if cupy_utils.is_initialized():
+        cupy_world_size = cupy_utils.get_world_size()
+        if cupy_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "cupy.distributed is already initialized but the cupy world "
+                "size does not match parallel_config.world_size "
+                f"({cupy_world_size} vs. {parallel_config.world_size}).")
+    elif parallel_config.world_size > 1:
+        # NOTE(woosuk): We don't initialize CuPy process group when world size
+        # is 1.
+        # TODO(woosuk): Support multi-node connection.
+        cupy_utils.init_process_group(
+            world_size=parallel_config.world_size,
+            rank=rank,
+            host="localhost",
+            port=cupy_port,
+        )
+
-    # A small all_reduce for warmup.
-    torch.distributed.all_reduce(torch.zeros(1).cuda())
+    if parallel_config.world_size > 1:
+        # A small all_reduce for warmup.
+        torch.distributed.all_reduce(torch.zeros(1).cuda())
+        cupy_utils.all_reduce(torch.zeros(1).cuda())
+
     initialize_model_parallel(parallel_config.tensor_parallel_size,
                               parallel_config.pipeline_parallel_size)
...
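
Taken together, the worker-side changes imply roughly the following start-up order. This driver snippet is a hedged sketch: only the method names come from worker.py above, the engine-side caller is not part of this diff, and the constructor arguments are placeholders.

from vllm.utils import get_open_port

worker = Worker(...)                           # construction is unchanged by this diff
worker.init_model(cupy_port=get_open_port())   # torch.distributed + CuPy NCCL setup
worker.init_cache_engine(cache_config)         # allocates gpu_cache, sets block size
worker.warm_up_model()                         # captures CUDA graphs unless enforce_eager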