Unverified commit ad66f6ef, authored by OlivierDehaene, committed by GitHub

feat(server): optim flash causal lm decode_token (#285)

parent bc5c0723
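The diff below removes the per-forward construction of cu_seqlens_q from the flash Llama, GPT-NeoX and Santacoder models and instead takes it as a new forward argument, so the caller can build (and reuse) the tensor for a decode step instead of re-allocating it inside every model call. A minimal sketch of the idea, assuming the caller is the flash causal LM batch code; build_decode_cu_seqlens_q is a hypothetical helper for illustration, not a function from this repository:

import torch

# Hypothetical helper (not part of this commit): during decode each sequence
# contributes exactly one query token, so the cumulative query lengths are
# simply 0, 1, ..., batch_size. Building this tensor once on the caller side
# replaces the torch.arange(...) that the hunks below delete from every
# model forward.
def build_decode_cu_seqlens_q(batch_size: int, device: torch.device) -> torch.Tensor:
    # Same length as cu_seqlens (batch_size + 1 entries), matching the removed
    # torch.arange(cu_seqlens.shape[0], ...) call.
    return torch.arange(batch_size + 1, dtype=torch.int32, device=device)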
@@ -554,6 +554,7 @@ class FlashLlamaModel(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -575,15 +576,11 @@ class FlashLlamaModel(torch.nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
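For context, a small worked illustration of the decode-path values above (the numbers are made up for the example, not taken from the code):

import torch

# Three sequences whose total (past + current) lengths are 5, 3 and 7 tokens:
cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)

# Index of each sequence's newest token in the flattened token dimension,
# i.e. where the freshly computed key/value pair is written.
layer_past_present_indices = cu_seqlens[1:] - 1  # tensor([ 4,  7, 14])

# One query token per sequence during decode, hence a plain 0..batch_size range;
# this is the tensor that is now passed in as cu_seqlens_q instead of being
# rebuilt here on every forward.
cu_seqlens_q = torch.arange(cu_seqlens.shape[0], dtype=torch.int32)  # tensor([0, 1, 2, 3])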
@@ -650,6 +647,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -658,6 +656,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
@@ -617,6 +617,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
@@ -638,15 +639,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
@@ -726,6 +723,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -734,6 +732,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
@@ -484,6 +484,7 @@ class FlashSantacoderModel(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -507,15 +508,11 @@ class FlashSantacoderModel(nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         residual = None
@@ -566,6 +563,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -574,6 +572,7 @@ class FlashSantacoderForCausalLM(nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
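The remaining hunks touch the model wrappers rather than the modeling code: the flash models now select plain float16 instead of preferring bfloat16 when the GPU reports support for it. A minimal before/after sketch of the two dtype policies, for illustration only:

import torch

# Old selection: bfloat16 on GPUs that report support (typically Ampere or newer),
# float16 otherwise.
old_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# New selection: always float16 for the flash models.
new_dtype = torch.float16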
@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
@@ -161,7 +161,7 @@ class FlashLlamaSharded(FlashLlama):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
@@ -38,7 +38,7 @@ class FlashNeoXSharded(FlashNeoX):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
@@ -31,7 +31,7 @@ class FlashSantacoder(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoder is only available on GPU")
@@ -178,7 +178,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")