"...resnet50_tensorflow.git" did not exist on "a0be68855440874f5086f73c87c0da64bf554594"
Unverified Commit f5806041 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Flax] Fix PyTorch import error (#11839)

* fix_torch_device_generate_test

* remove @

* change pytorch import to flax import
parent 0cbddfb1
...@@ -42,7 +42,7 @@ from flax.training import train_state ...@@ -42,7 +42,7 @@ from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard from flax.training.common_utils import get_metrics, onehot, shard
from transformers import ( from transformers import (
CONFIG_MAPPING, CONFIG_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING, FLAX_MODEL_FOR_MASKED_LM_MAPPING,
AutoConfig, AutoConfig,
AutoTokenizer, AutoTokenizer,
FlaxAutoModelForMaskedLM, FlaxAutoModelForMaskedLM,
...@@ -71,7 +71,7 @@ else: ...@@ -71,7 +71,7 @@ else:
) )
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys()) MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment