Unverified commit fb2d0680, authored by Lianmin Zheng, committed by GitHub

[Fix] Fix clean_up_tokenization_spaces in tokenizer (#1510)

parent 067d8e16
@@ -129,6 +129,7 @@ def get_tokenizer(
             *args,
             trust_remote_code=trust_remote_code,
             tokenizer_revision=tokenizer_revision,
+            clean_up_tokenization_spaces=False,
             **kwargs,
         )
     except TypeError as e:
...
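For context, clean_up_tokenization_spaces controls whether Hugging Face tokenizers strip the space before punctuation when decoding. With the historical default of True, decoded text can differ from the exact token sequence the model emitted, which matters whenever outputs are compared verbatim. A minimal sketch of the difference; gpt2 is an arbitrary public checkpoint used only for illustration:

# Sketch only: demonstrates the flag this commit pins to False.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # arbitrary checkpoint for illustration
ids = tok.encode("Hello , world !")

# Cleanup merges the space before punctuation on decode...
print(tok.decode(ids, clean_up_tokenization_spaces=True))   # "Hello, world!"
# ...while False round-trips the tokens exactly as encoded.
print(tok.decode(ids, clean_up_tokenization_spaces=False))  # "Hello , world !"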
@@ -21,8 +21,9 @@ from typing import List, Union
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM

+from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.server import Runtime
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
@@ -92,11 +93,7 @@ class HFRunner:
         self.model_proc.start()

     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path,
-            torch_dtype=torch_dtype,
-        )
+        self.tokenizer = get_tokenizer(model_path)
         if self.is_generation:
             self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
...
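Routing HFRunner through get_tokenizer means the Hugging Face reference now decodes with the same clean_up_tokenization_spaces=False default as the SRT side, so ROUGE-L comparisons between the two runners are not skewed by punctuation-space cleanup. A minimal sketch of the invariant this buys, with gpt2 again standing in for the model under test:

# Sketch only: gpt2 stands in for the model under test.
from transformers import AutoTokenizer
from sglang.srt.hf_transformers_utils import get_tokenizer

path = "gpt2"
runner_tok = get_tokenizer(path)                   # clean_up_tokenization_spaces=False baked in
default_tok = AutoTokenizer.from_pretrained(path)  # what HFRunner constructed before

ids = default_tok.encode("a , b .")
print(runner_tok.decode(ids))    # "a , b ."  -- spacing preserved
print(default_tok.decode(ids))   # "a, b."    -- cleanup applied (older transformers defaults)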
@@ -26,12 +26,14 @@ I'm going to the
 import argparse

 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM
+
+from sglang.srt.hf_transformers_utils import get_tokenizer


 @torch.inference_mode()
 def normal_text(args):
-    t = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+    t = get_tokenizer(args.model_path, trust_remote_code=True)
     m = AutoModelForCausalLM.from_pretrained(
         args.model_path,
         torch_dtype=torch.float16,
...
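The reference script follows the same pattern, so its printed completions stay byte-comparable with the server's. A rough sketch of the decode path the script relies on, assuming a CUDA device; reference_decode and its generation settings are illustrative, not the script's exact code:

# Rough sketch of the reference flow; names and settings are illustrative.
import torch
from transformers import AutoModelForCausalLM
from sglang.srt.hf_transformers_utils import get_tokenizer

@torch.inference_mode()
def reference_decode(model_path, prompt, max_new_tokens=32):
    t = get_tokenizer(model_path, trust_remote_code=True)
    m = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, trust_remote_code=True
    ).cuda()
    ids = t.encode(prompt, return_tensors="pt").cuda()
    out = m.generate(ids, do_sample=False, max_new_tokens=max_new_tokens)
    # decode() inherits clean_up_tokenization_spaces=False from get_tokenizer.
    return t.decode(out[0][ids.shape[1]:])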
@@ -30,7 +30,7 @@ from typing import List
 import torch

 from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
-from sglang.test.test_utils import calculate_rouge_l
+from sglang.test.test_utils import calculate_rouge_l, is_in_ci


 @dataclasses.dataclass
@@ -132,6 +132,9 @@ class TestGenerationModels(unittest.TestCase):
         )

     def test_others(self):
+        if is_in_ci():
+            return
+
         for model_case in ALL_OTHER_MODELS:
             if (
                 "ONLY_RUN" in os.environ
...
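The new is_in_ci guard skips the heavyweight ALL_OTHER_MODELS sweep on CI runners. Its implementation is not part of this diff; a plausible sketch, assuming the helper keys off an environment variable (the variable name below is a guess, not confirmed by the source):

# Plausible sketch; SGLANG_IS_IN_CI is an assumed variable name.
import os

def is_in_ci() -> bool:
    # CI pipelines would export this; local runs leave it unset.
    return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"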