Unverified Commit f9958ee1 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing cohere tokenizer. (#1697)

parent 5062fda4
...@@ -3,7 +3,7 @@ import torch.distributed ...@@ -3,7 +3,7 @@ import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from typing import Optional from typing import Optional
from transformers.models.llama import LlamaTokenizerFast from transformers import AutoTokenizer
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
...@@ -36,7 +36,7 @@ class FlashCohere(FlashCausalLM): ...@@ -36,7 +36,7 @@ class FlashCohere(FlashCausalLM):
else: else:
raise NotImplementedError("FlashCohere is only available on GPU") raise NotImplementedError("FlashCohere is only available on GPU")
tokenizer = LlamaTokenizerFast.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
padding_side="left", padding_side="left",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment