Unverified Commit 66914f7b authored by SeongBeomLEE, committed by GitHub

fix: LlamaTokenizerFast to AutoTokenizer at flash_mistral.py (#1637)

# What does this PR do?

There are a few cases where a model uses the Mistral or Mixtral architecture but not a Llama tokenizer, so this PR falls back to AutoTokenizer via exception handling.
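
A minimal, standalone sketch of the fallback pattern (the `load_mistral_tokenizer` helper name is illustrative, not part of this PR):

```python
from transformers import AutoTokenizer
from transformers.models.llama import LlamaTokenizerFast


def load_mistral_tokenizer(model_id, revision=None, trust_remote_code=False):
    # Most Mistral/Mixtral checkpoints ship a Llama tokenizer, so try the
    # fast Llama tokenizer first.
    try:
        return LlamaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
    except Exception:
        # For checkpoints that use a different tokenizer, AutoTokenizer
        # resolves the correct tokenizer class from the checkpoint's config.
        return AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )
```

Since both branches pass identical arguments, the fallback changes only which tokenizer class is tried, not how the tokenizer is configured.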

Similar PR #619

@Narsil
parent 08e91814
@@ -6,7 +6,7 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
-from transformers import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase, AutoTokenizer
 from transformers.models.llama import LlamaTokenizerFast
 from typing import Optional, Tuple, Type
@@ -317,13 +317,22 @@ class BaseFlashMistral(FlashCausalLM):
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")

-        tokenizer = LlamaTokenizerFast.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
+        try:
+            tokenizer = LlamaTokenizerFast.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )
+        except Exception:
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_id,
+                revision=revision,
+                padding_side="left",
+                truncation_side="left",
+                trust_remote_code=trust_remote_code,
+            )

         config = config_cls.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code