Unverified Commit 99879600 authored by OlivierDehaene, committed by GitHub

feat(router): make router input validation optional (#164)

parent 7dec65a2
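Besides making the router's validation optional, the server-side diffs below move truncation into the Python model servers: every tokenizer is created with truncation_side="left", and each batch is tokenized once with truncation=True and max_length set to the largest truncate value among its requests. A minimal sketch of that pattern, with made-up request objects and gpt2 standing in for the real model id:

    from transformers import AutoTokenizer

    # Stand-ins for the gRPC requests; the field names mirror the r.inputs / r.truncate
    # usage in the batch code below, the values are made up.
    requests = [
        {"inputs": "a very long prompt " * 50, "truncate": 16},
        {"inputs": "short prompt", "truncate": 64},
    ]

    # Tokenizers are now built with truncation_side="left" (gpt2 is only a stand-in model id).
    tokenizer = AutoTokenizer.from_pretrained(
        "gpt2", padding_side="left", truncation_side="left"
    )
    tokenizer.pad_token = tokenizer.eos_token

    # One tokenizer call per batch, capped at the largest per-request truncation.
    max_truncation = max(r["truncate"] for r in requests)
    tokenized_inputs = tokenizer(
        [r["inputs"] for r in requests],
        return_tensors="pt",
        padding=True,
        return_token_type_ids=False,
        truncation=True,
        max_length=max_truncation,
    )
    input_lengths = tokenized_inputs["attention_mask"].sum(1)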
@@ -45,18 +45,19 @@ class FlashNeoXSharded(FlashNeoX):
             raise NotImplementedError("FlashNeoX does not support quantization")
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
 
         with init_empty_weights():
-            model = FlashGPTNeoXForCausalLM(config)
+            model = FlashGPTNeoXForCausalLM(config, self.process_group)
 
         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
@@ -147,32 +148,3 @@ class FlashNeoXSharded(FlashNeoX):
                         module._parameters[param_name] = tensor
                     else:
                         module._buffers[param_name] = tensor
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
-        cu_seqlens: torch.Tensor,
-        max_s: int,
-        past_key_values: Optional = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.model.gpt_neox.tp_embeddings:
-            logits, present = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlens=cu_seqlens,
-                max_s=max_s,
-                past_key_values=past_key_values,
-            )
-
-            # Logits are sharded, so we need to gather them
-            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
-            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
-            world_logits = torch.cat(world_logits, dim=1)
-
-            return world_logits, present
-        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
-        else:
-            return super(FlashNeoXSharded, self).forward(
-                input_ids, position_ids, cu_seqlens, max_s, past_key_values
-            )
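The sharded forward override is dropped here; since FlashGPTNeoXForCausalLM is now constructed with self.process_group (see the hunk above), the tensor-parallel logit gather presumably happens inside the model itself, an assumption this diff does not show. For reference, the gather the removed override performed can be written as a standalone helper (a sketch):

    import torch
    import torch.distributed

    def gather_sharded_logits(logits: torch.Tensor, process_group) -> torch.Tensor:
        # Each rank holds a vocab shard of shape (num_tokens, vocab_size / world_size);
        # all_gather followed by concatenation along the vocab dimension rebuilds full logits.
        world_size = torch.distributed.get_world_size(group=process_group)
        world_logits = [torch.empty_like(logits) for _ in range(world_size)]
        torch.distributed.all_gather(world_logits, logits, group=process_group)
        return torch.cat(world_logits, dim=1)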
@@ -33,7 +33,7 @@ class FlashSantacoder(FlashCausalLM):
             raise NotImplementedError("FlashSantacoder does not support quantization")
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(
@@ -56,6 +56,8 @@ class FlashSantacoder(FlashCausalLM):
         self.load_weights(
             model,
             filenames,
+            device,
+            dtype,
         )
         self.model = model.eval().to(device).to(dtype)
@@ -68,10 +70,14 @@ class FlashSantacoder(FlashCausalLM):
     def load_weights(
         model: FlashSantacoderForCausalLM,
         filenames: List[Path],
+        device: torch.device,
+        dtype: torch.dtype,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
+                value = value.to(device).to(dtype)
+
                 layer_name = ".".join(key.split(".")[:4])
 
                 # Fused qkv
@@ -141,6 +147,8 @@ class FlashSantacoder(FlashCausalLM):
                 else:
                     module._buffers[param_name] = value
+                del value
 
+        torch.cuda.empty_cache()
         model.post_load_weights()
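Passing device and dtype into FlashSantacoder.load_weights means each checkpoint tensor is cast and moved as it is read, with del value and torch.cuda.empty_cache() releasing intermediates early, instead of building the whole model on CPU and converting it at the end. A stripped-down sketch of that loading loop (it ignores the fused-qkv and layer-name handling the real loader does, and the get_submodule mapping is an assumption):

    import torch

    def load_weights_streaming(model, filenames, device, dtype):
        for filename in filenames:
            # Each shard is read on CPU, then its tensors are cast and moved one at a time.
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                value = value.to(device).to(dtype)
                module_name, param_name = key.rsplit(".", 1)
                module = model.get_submodule(module_name)
                # Mirrors the diff's direct assignment into _parameters; buffers are handled the same way.
                module._parameters[param_name] = value
                del value
        torch.cuda.empty_cache()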
@@ -96,7 +96,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         input_lengths = []
 
         # Parse batch
-        max_sequence_length = 0
+        max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             # Add escape_custom_split_sequence to the CausalLMBatch logic
@@ -107,7 +107,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_sequence_length = max(max_sequence_length, r.input_length)
+            max_truncation = max(max_truncation, r.truncate)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -118,14 +118,20 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
+
+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
 
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@@ -143,7 +149,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
         )
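With truncation done by the tokenizer, max_input_length is measured from the attention mask it returns rather than from the raw request lengths, and the preallocated mask is sized from that plus the right padding reserved for generation. A small worked example with made-up numbers:

    import torch

    # Suppose two requests tokenize to 3 and 5 tokens after left truncation,
    # and the longest request may still generate up to 4 new tokens.
    tokenized_attention_mask = torch.tensor(
        [[0, 0, 1, 1, 1],   # padded on the left to the longest prompt
         [1, 1, 1, 1, 1]]
    )
    input_lengths = tokenized_attention_mask.sum(1)   # tensor([3, 5])
    max_input_length = input_lengths.max()            # 5
    padding_right_offset = 4

    # Allocate room for the prompts plus every token that could still be generated,
    # then copy the tokenizer mask into the prompt part.
    attention_mask = tokenized_attention_mask.new_zeros(
        (2, max_input_length + padding_right_offset)
    )
    attention_mask[:, :max_input_length] = tokenized_attention_mask
    print(attention_mask.shape)  # torch.Size([2, 9])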
@@ -188,7 +194,7 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
        )
 
         config = AutoConfig.from_pretrained(
@@ -44,7 +44,7 @@ class GPTNeoxSharded(CausalLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.pad_token = tokenizer.eos_token
@@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.add_special_tokens(
             {
@@ -73,6 +73,7 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_lengths = []
 
         # Parse batch
+        max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
@@ -84,6 +85,7 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
+            max_truncation = max(max_truncation, r.truncate)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )
@@ -94,6 +96,8 @@ class Seq2SeqLMBatch(Batch):
             return_tensors="pt",
             padding=True,
             return_token_type_ids=False,
+            truncation=True,
+            max_length=max_truncation,
         ).to(device)
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
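The same cap applies to Seq2SeqLMBatch. The truncation_side="left" setting added to every tokenizer above is what makes this safe for generation: when a prompt exceeds max_length, the tokenizer drops tokens from the beginning and keeps the most recent context. A tiny illustration (gpt2 is only a stand-in tokenizer):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")
    ids = tok("one two three four five", truncation=True, max_length=3)["input_ids"]
    print(tok.decode(ids))  # keeps the tail of the prompt, e.g. " three four five"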
@@ -44,7 +44,7 @@ class T5Sharded(Seq2SeqLM):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
         config = AutoConfig.from_pretrained(