"llm/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "7c40a67841fd32073b66984e24605e5a0cc46f1a"
Unverified commit 8bcf9c8d authored by Cyril Vallez, committed by GitHub

Fix jetmoe model (#31279)

* Fix jetmoe model

* Remove skip-tests
parent f868cf73
@@ -1404,7 +1404,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
         past_length = 0
         if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
             past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
             max_cache_length = (
                 torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
@@ -1412,10 +1412,6 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
                 else None
             )
             cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

         # Keep only the unprocessed tokens:
         # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1446,7 +1442,7 @@ class JetMoeForCausalLM(JetMoePreTrainedModel):
             position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
......
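The two changes to `prepare_inputs_for_generation` above fit together: since `past_key_values` is always handed in as a `Cache` object (per the new in-line comment), the legacy tuple `else` branch is dead code, and an empty cache is no longer `None`, so the first generation step has to be detected through `past_length == 0` rather than `past_key_values is None`. A minimal sketch of that distinction (illustrative only; `DynamicCache` stands in for whatever cache `generate` passes, and the dummy tensor is not from the PR):

    import torch
    from transformers import DynamicCache

    # An empty cache is not None, so the old first-step check never fired
    # when inputs_embeds were passed; the new check keys off past_length.
    past_key_values = DynamicCache()
    cache_position = None
    inputs_embeds = torch.zeros(1, 3, 8)  # dummy embeddings, illustrative only

    past_length = 0
    if past_key_values is not None:
        past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()

    print(past_key_values is None)                          # False -> old check: embeds would be ignored
    print(inputs_embeds is not None and past_length == 0)   # True  -> new check: embeds used on step one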
@@ -472,14 +472,6 @@ class JetMoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_flash_attn_2_inference_equivalence_right_padding(self):
         self.skipTest("JetMoe flash attention does not support right padding")

-    @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
-    def test_beam_sample_generate(self):
-        pass
-
-    @unittest.skip("TODO: @ArthurZucker - Breaks after #30536 ")
-    def test_generate_from_inputs_embeds_decoder_only(self):
-        pass
-

 @require_torch
 class JetMoeIntegrationTest(unittest.TestCase):
......
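Since the commit also drops the two `unittest.skip` markers, one way to sanity-check the fix locally is to re-run exactly those tests. A hypothetical invocation, assuming a transformers development checkout with pytest installed and the usual test module path:

    import pytest

    # Re-run only the two JetMoe generation tests whose skip markers this commit removes.
    pytest.main([
        "tests/models/jetmoe/test_modeling_jetmoe.py",
        "-k", "beam_sample_generate or generate_from_inputs_embeds_decoder_only",
        "-v",
    ])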