Unverified commit 99879600, authored by OlivierDehaene, committed by GitHub
feat(router): make router input validation optional (#164)

parent 7dec65a2
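Router-side length validation becomes optional because truncation now happens in the Python server at tokenization time: every tokenizer below is created with truncation_side="left", and each batch is encoded with truncation=True and a per-batch max_length. A minimal sketch of that behaviour, assuming a generic causal-LM tokenizer and an illustrative max_length of 64 (neither is part of this diff):

from transformers import AutoTokenizer

# Illustration only: any causal-LM tokenizer works; 64 plays the role of max_truncation below.
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-neox-20b",
    padding_side="left",
    truncation_side="left",
)
tokenizer.pad_token = tokenizer.eos_token

batch = tokenizer(
    ["a very long prompt " * 200, "short prompt"],
    return_tensors="pt",
    padding=True,
    return_token_type_ids=False,
    truncation=True,
    max_length=64,
)
# Over-long prompts keep their most recent 64 tokens; per-request lengths come from the mask.
print(batch["input_ids"].shape)        # torch.Size([2, 64])
print(batch["attention_mask"].sum(1))  # input lengths after truncation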
@@ -45,18 +45,19 @@ class FlashNeoXSharded(FlashNeoX):
            raise NotImplementedError("FlashNeoX does not support quantization")

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
-            model = FlashGPTNeoXForCausalLM(config)
+            model = FlashGPTNeoXForCausalLM(config, self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
@@ -147,32 +148,3 @@ class FlashNeoXSharded(FlashNeoX):
                    module._parameters[param_name] = tensor
                else:
                    module._buffers[param_name] = tensor
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        max_s: int,
-        past_key_values: Optional = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.model.gpt_neox.tp_embeddings:
-            logits, present = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlens=cu_seqlens,
-                max_s=max_s,
-                past_key_values=past_key_values,
-            )
-
-            # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
-
-            return world_logits, present
-        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
-        else:
-            return super(FlashNeoXSharded, self).forward(
-                input_ids, position_ids, cu_seqlens, max_s, past_key_values
-            )
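The removed forward gathered vocabulary-sharded logits from every tensor-parallel rank before returning them. For reference, a hedged, self-contained sketch of that all_gather pattern (the function name is mine, and an initialised torch.distributed process group is assumed):

import torch
import torch.distributed as dist

def gather_sharded_logits(local_logits: torch.Tensor, process_group=None) -> torch.Tensor:
    """Reassemble logits that are sharded along the vocabulary dimension."""
    world_size = dist.get_world_size(group=process_group)
    # One receive buffer per rank, then concatenate along the sharded dimension,
    # matching the dim=1 concatenation in the removed forward above.
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits, group=process_group)
    return torch.cat(shards, dim=1)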
@@ -33,7 +33,7 @@ class FlashSantacoder(FlashCausalLM):
            raise NotImplementedError("FlashSantacoder does not support quantization")

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
@@ -56,6 +56,8 @@ class FlashSantacoder(FlashCausalLM):
        self.load_weights(
            model,
            filenames,
+            device,
+            dtype,
        )
        self.model = model.eval().to(device).to(dtype)
@@ -68,10 +70,14 @@ class FlashSantacoder(FlashCausalLM):
    def load_weights(
        model: FlashSantacoderForCausalLM,
        filenames: List[Path],
+        device: torch.device,
+        dtype: torch.dtype,
    ):
        for filename in filenames:
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
+                value = value.to(device).to(dtype)
+
                layer_name = ".".join(key.split(".")[:4])

                # Fused qkv
@@ -141,6 +147,8 @@ class FlashSantacoder(FlashCausalLM):
                else:
                    module._buffers[param_name] = value

+                del value
+
        torch.cuda.empty_cache()
        model.post_load_weights()
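Moving each tensor to the target device and dtype inside the loading loop, then deleting the host copy right away, keeps only one checkpoint file's worth of CPU tensors alive at a time instead of materialising the whole state dict and converting it afterwards. An illustrative sketch of that pattern (not the repository's exact loader; the parameter-copying details are elided):

import torch

def load_weights_streaming(model, filenames, device, dtype):
    """Illustration only: stream tensors to device/dtype one at a time."""
    for filename in filenames:
        state_dict = torch.load(filename, map_location="cpu")
        for key, value in state_dict.items():
            value = value.to(device).to(dtype)
            # ... copy `value` into the matching parameter or buffer of `model` ...
            del value
    if device.type == "cuda":
        torch.cuda.empty_cache()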
...
@@ -96,7 +96,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
        input_lengths = []

        # Parse batch
-        max_sequence_length = 0
+        max_truncation = 0
        padding_right_offset = 0
        for r in pb.requests:
            # Add escape_custom_split_sequence to the CausalLMBatch logic
@@ -107,7 +107,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
-            max_sequence_length = max(max_sequence_length, r.input_length)
+            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )
@@ -118,14 +118,20 @@ class GalacticaCausalLMBatch(CausalLMBatch):
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
        ).to(device)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()

        input_ids = tokenized_inputs["input_ids"]
        # Allocate maximum attention_mask
        attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
        )
        # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
        position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
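The allocation above sizes the attention mask once for the whole generation: max_input_length columns for the (possibly truncated) prompt plus padding_right_offset columns for the largest max_new_tokens in the batch, so the mask never has to grow during decoding. A small, self-contained illustration of that bookkeeping, using toy tensors and assuming pad id 0:

import torch

# Toy left-padded batch; pad id 0 is an assumption for this illustration.
input_ids = torch.tensor([[0, 0, 11, 12],
                          [21, 22, 23, 24]])
tokenizer_mask = (input_ids != 0).long()      # what tokenized_inputs["attention_mask"] holds
input_lengths = tokenizer_mask.sum(1)         # tensor([2, 4])
max_input_length = int(input_lengths.max())   # 4
padding_right_offset = 3                      # largest max_new_tokens in the batch

# Allocate once with room for every future token, then copy the tokenizer mask in.
attention_mask = input_ids.new_zeros(
    (input_ids.shape[0], max_input_length + padding_right_offset)
)
attention_mask[:, :max_input_length] = tokenizer_mask

# Position ids count real tokens only; padding positions get a dummy value of 1.
position_ids = tokenizer_mask.cumsum(-1) - 1
position_ids.masked_fill_(tokenizer_mask == 0, 1)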
@@ -143,7 +149,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
            padding_right_offset=padding_right_offset,
        )
@@ -188,7 +194,7 @@ class GalacticaSharded(Galactica):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
...
@@ -44,7 +44,7 @@ class GPTNeoxSharded(CausalLM):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        tokenizer.pad_token = tokenizer.eos_token
...
@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
        tokenizer.add_special_tokens(
            {
...
@@ -73,6 +73,7 @@ class Seq2SeqLMBatch(Batch):
        decoder_input_lengths = []

        # Parse batch
+        max_truncation = 0
        padding_right_offset = 0
        for r in pb.requests:
            inputs.append(r.inputs)
@@ -84,6 +85,7 @@ class Seq2SeqLMBatch(Batch):
                r.stopping_parameters, tokenizer
            )
            stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)
            padding_right_offset = max(
                padding_right_offset, stopping_criteria.max_new_tokens
            )
@@ -94,6 +96,8 @@ class Seq2SeqLMBatch(Batch):
            return_tensors="pt",
            padding=True,
            return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
        ).to(device)

        input_lengths = tokenized_inputs["attention_mask"].sum(1)
...
@@ -44,7 +44,7 @@ class T5Sharded(Seq2SeqLM):
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
...