Unverified Commit 304c6a1e authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Enable fx tracing for Mistral (#30209)

* tracing for mistral

* typo

* fix copies
parent 98717cb3
@@ -868,9 +868,6 @@ class MixtralSparseMoeBlock(nn.Module):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
-            if top_x.shape[0] == 0:
-                continue
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
......
@@ -840,9 +840,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
-            if top_x.shape[0] == 0:
-                continue
             # Index the correct hidden states and compute the expert hidden state for
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
......
@@ -141,12 +141,16 @@ _REGULAR_SUPPORTED_MODEL_NAMES_AND_TASKS = [
     "marian",
     "mbart",
     "megatron-bert",
+    "mistral",
+    "mixtral",
     "mobilebert",
     "mt5",
     "nezha",
     "opt",
     "pegasus",
     "plbart",
+    "qwen2",
+    "qwen2_moe",
     "resnet",
     "roberta",
     "segformer",
@@ -758,6 +762,7 @@ class HFTracer(Tracer):
         "tensor",
         "clamp",
         "finfo",
+        "tril",
     ]
     supported_archs = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
......
@@ -303,6 +303,7 @@ class MistralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
......
@@ -302,6 +302,7 @@ class MixtralModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
......
@@ -313,6 +313,7 @@ class Qwen2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
......
@@ -342,6 +342,7 @@ class Qwen2MoeModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
     )
     test_headmasking = False
     test_pruning = False
+    fx_compatible = True
     # TODO (ydshieh): Check this. See https://app.circleci.com/pipelines/github/huggingface/transformers/79245/workflows/9490ef58-79c2-410d-8f51-e3495156cf9c/jobs/1012146
     def is_pipeline_test_to_skip(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment