Unverified Commit ad66f6ef authored by OlivierDehaene, committed by GitHub

feat(server): optim flash causal lm decode_token (#285)

parent bc5c0723
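
The change threads `cu_seqlens_q` through the flash model forwards instead of recomputing it on every call. During decode each sequence in the batch contributes exactly one query token, so the query cumulative sequence lengths are just `0, 1, ..., batch_size`; that tensor depends only on the batch size, so the caller can build it once and reuse it for every decode step. Below is a minimal sketch of the idea; the helper name is hypothetical and the caller-side wiring (in the `FlashCausalLM` batch code) is not part of this diff.

```python
import torch
from typing import List, Tuple


def build_decode_seqlens(
    sequence_lengths: List[int], device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper (not from this diff): cumulative sequence lengths
    for a decode step with flash attention."""
    batch_size = len(sequence_lengths)
    # cu_seqlens covers the keys/values of each sequence in the flattened
    # batch, e.g. lengths [3, 5, 2] -> [0, 3, 8, 10].
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = torch.cumsum(
        torch.tensor(sequence_lengths, dtype=torch.int32, device=device), dim=0
    )
    # Each sequence has exactly one query token during decode, so the query
    # offsets are simply 0..batch_size. This only depends on the batch size,
    # which is why it can be precomputed once instead of being rebuilt with
    # torch.arange inside every model forward (the code removed below).
    cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=device)
    return cu_seqlens, cu_seqlens_q
```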
@@ -554,6 +554,7 @@ class FlashLlamaModel(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -575,15 +576,11 @@ class FlashLlamaModel(torch.nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
@@ -650,6 +647,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -658,6 +656,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
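For context on the decode branch kept above: `layer_past_present_indices = cu_seqlens[1:] - 1` gives, for each sequence, the position of its newest token in the flattened key/value cache, i.e. the slot where the freshly computed key/value pair is written during decode. A toy illustration (not part of the diff):

```python
import torch

# Three sequences with key/value lengths 3, 5 and 2 flattened into one buffer:
# cumulative lengths are [0, 3, 8, 10].
cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)

# Last (current) token of each sequence in the flattened buffer.
layer_past_present_indices = cu_seqlens[1:] - 1
print(layer_past_present_indices)  # tensor([2, 7, 9], dtype=torch.int32)
```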
@@ -617,6 +617,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
@@ -638,15 +639,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
@@ -726,6 +723,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -734,6 +732,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
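The same `cu_seqlens_q` parameter is added to the Santacoder model below. On the caller side the only difference between prefill and decode is what gets passed for it: during prefill it is unused (the old code set it to `None`), while during decode the precomputed tensor is reused. A hypothetical wrapper showing the updated calling convention, assuming a model with the new signature:

```python
import torch
from typing import Optional, Tuple


def forward_step(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    cu_seqlens: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor],
    max_s: int,
    past_key_values: Optional[torch.Tensor] = None,
    pre_allocate_past_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Prefill: past_key_values is None and cu_seqlens_q may be None.
    # Decode: cu_seqlens_q is the per-batch arange built once by the caller.
    return model(
        input_ids,
        position_ids,
        cu_seqlens,
        cu_seqlens_q,
        max_s,
        past_key_values,
        pre_allocate_past_size,
    )
```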
@@ -484,6 +484,7 @@ class FlashSantacoderModel(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -507,15 +508,11 @@ class FlashSantacoderModel(nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         residual = None
@@ -566,6 +563,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -574,6 +572,7 @@ class FlashSantacoderForCausalLM(nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

@@ -161,7 +161,7 @@ class FlashLlamaSharded(FlashLlama):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

@@ -38,7 +38,7 @@ class FlashNeoXSharded(FlashNeoX):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

@@ -31,7 +31,7 @@ class FlashSantacoder(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoder is only available on GPU")

@@ -178,7 +178,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

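Separately, the flash model loaders above now pin the weight dtype to `torch.float16` instead of preferring `bfloat16` when the GPU reports support for it. A small illustrative check (not part of the diff) that a loaded flash model ended up in the expected dtype:

```python
import torch


def assert_float16(model: torch.nn.Module) -> None:
    # Illustrative helper: verify every parameter was loaded in float16,
    # matching the dtype the flash model loaders now select unconditionally.
    for name, param in model.named_parameters():
        assert param.dtype == torch.float16, f"{name} has dtype {param.dtype}"
```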