Unverified Commit 02b72586 authored by 胡译文's avatar 胡译文 Committed by GitHub
Browse files

[Feat] Expose logprob options to `sgl.gen` API (#503)


Co-authored-by: default avatarLianmin Zheng <lianminzheng@gmail.com>
parent d557e9f3
...@@ -279,8 +279,8 @@ for out in state.text_iter(): ...@@ -279,8 +279,8 @@ for out in state.text_iter():
``` ```
### Tips and Implementation Details ### Tips and Implementation Details
- The `choices` argument in `sgl.gen` is implemented by computing the normalized log probabilities of all choices and selecting the one with the highest probability. - The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability.
- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. - The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`.
## Backend: SGLang Runtime (SRT) ## Backend: SGLang Runtime (SRT)
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend. The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
...@@ -337,7 +337,6 @@ response = client.chat.completions.create( ...@@ -337,7 +337,6 @@ response = client.chat.completions.create(
print(response) print(response)
``` ```
By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3. By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3.
If needed, you can also override the chat template when launching the server: If needed, you can also override the chat template when launching the server:
...@@ -384,9 +383,8 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port ...@@ -384,9 +383,8 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
- Llama - Llama
- Mistral - Mistral
- Mixtral - Mixtral
- Qwen / Qwen 2 - Qwen / Qwen 2 / Qwen 2 MoE
- Gemma - Gemma / Gemma 2
- Please add a new flag `--attention-reduce-in-fp32` to avoid some precision errors.
- `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32` - `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32`
- LLaVA - LLaVA
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
...@@ -399,6 +397,8 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port ...@@ -399,6 +397,8 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
- StableLM - StableLM
- Command-R - Command-R
- DBRX - DBRX
- Grok
- ChatGLM
- AWQ/GPTQ/Marlin quantization - AWQ/GPTQ/Marlin quantization
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md). Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).
......
from math import exp
from pprint import pformat
import sglang as sgl
YELLOW = "\033[1;33m"
GREEN = "\033[1;32m"
BLUE = "\033[1;34m"
CLEAR = "\033[1;0m"
@sgl.function
def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
"""CoT Decoding: http://arxiv.org/abs/2402.10200"""
if is_chat_model:
s += sgl.user("Question: " + question + "\nAnswer:")
s += sgl.assistant_begin()
else:
s += "Question: " + question + "\nAnswer:"
step_0 = s.fork(1)[0]
forks = s.fork(get_top_k)
answer_forks = s.fork(get_top_k)
# decoding step 0
step_0 += sgl.gen(
"get_top_k",
max_tokens=0,
return_logprob=True,
top_logprobs_num=get_top_k,
return_text_in_logprobs=True,
)
logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0]
print("Decoding step 0:",
", ".join(pformat(token[2]) for token in logprobs))
for idx, (f, token) in enumerate(zip(forks, logprobs)):
logprob, token_id, text = token
f += text
if text == "<|end_of_text|>":
print(
f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score=nan, answer=nan){CLEAR}"
)
continue
# continue greedy decoding
f += sgl.gen(
"answer",
temperature=0,
max_tokens=1024,
return_logprob=True,
top_logprobs_num=2,
return_text_in_logprobs=True,
)
# calculate probability disparity between the top and secondary tokens
x1s = [
exp(xt[0][0])
for xt in f.get_meta_info("answer")["decode_top_logprobs"]
]
x2s = [
exp(xt[1][0])
for xt in f.get_meta_info("answer")["decode_top_logprobs"]
]
tokens = [
xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"]
]
delta = (sum(x1s) - sum(x2s)) / len(x1s)
# extract the answer span (without the '<|end_of_text|>' token)
answer_forks[idx] += text + f["answer"] + "\nSo the answer is"
answer_forks[idx] += sgl.gen(
"answer_span",
temperature=0,
max_tokens=64,
return_logprob=True,
top_logprobs_num=2,
return_text_in_logprobs=True,
)
answer = answer_forks[idx]['answer_span'].replace('\n', ' ').strip(':')
print(
f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score={delta}, answer={answer}){CLEAR}"
)
generated_text = str(answer_forks[idx])[len("ProgramState("):-1]
print(f"{BLUE}{pformat(generated_text)}{CLEAR}")
if verbose:
answer_tokens = [
xt[0][2] for xt in answer_forks[idx].get_meta_info(
"answer_span")["decode_top_logprobs"]
]
answer_x1s = [
exp(xt[0][0]) for xt in answer_forks[idx].get_meta_info(
"answer_span")["decode_top_logprobs"]
]
answer_x2s = [
exp(xt[1][0]) for xt in answer_forks[idx].get_meta_info(
"answer_span")["decode_top_logprobs"]
]
for token, x1, x2 in zip(tokens, x1s, x2s):
print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})",
end="")
print("\n===========")
for token, x1, x2 in zip(answer_tokens, answer_x1s, answer_x2s):
print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})",
end="")
print()
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
state = cot_decoding.run(
question=
r"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
get_top_k=10,
is_chat_model=True,
verbose=False,
)
...@@ -67,10 +67,16 @@ def gen( ...@@ -67,10 +67,16 @@ def gen(
frequency_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None, ignore_eos: Optional[bool] = None,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None, dtype: Optional[type] = None,
choices: Optional[List[str]] = None, choices: Optional[List[str]] = None,
regex: Optional[str] = None, regex: Optional[str] = None,
): ):
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
if choices: if choices:
return SglSelect(name, choices, 0.0 if temperature is None else temperature) return SglSelect(name, choices, 0.0 if temperature is None else temperature)
...@@ -91,6 +97,10 @@ def gen( ...@@ -91,6 +97,10 @@ def gen(
frequency_penalty, frequency_penalty,
presence_penalty, presence_penalty,
ignore_eos, ignore_eos,
return_logprob,
logprob_start_len,
top_logprobs_num,
return_text_in_logprobs,
dtype, dtype,
regex, regex,
) )
...@@ -106,6 +116,10 @@ def gen_int( ...@@ -106,6 +116,10 @@ def gen_int(
frequency_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None, ignore_eos: Optional[bool] = None,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
): ):
return SglGen( return SglGen(
name, name,
...@@ -117,6 +131,10 @@ def gen_int( ...@@ -117,6 +131,10 @@ def gen_int(
frequency_penalty, frequency_penalty,
presence_penalty, presence_penalty,
ignore_eos, ignore_eos,
return_logprob,
logprob_start_len,
top_logprobs_num,
return_text_in_logprobs,
int, int,
None, None,
) )
...@@ -132,6 +150,10 @@ def gen_string( ...@@ -132,6 +150,10 @@ def gen_string(
frequency_penalty: Optional[float] = None, frequency_penalty: Optional[float] = None,
presence_penalty: Optional[float] = None, presence_penalty: Optional[float] = None,
ignore_eos: Optional[bool] = None, ignore_eos: Optional[bool] = None,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
): ):
return SglGen( return SglGen(
name, name,
...@@ -143,6 +165,10 @@ def gen_string( ...@@ -143,6 +165,10 @@ def gen_string(
frequency_penalty, frequency_penalty,
presence_penalty, presence_penalty,
ignore_eos, ignore_eos,
return_logprob,
logprob_start_len,
top_logprobs_num,
return_text_in_logprobs,
str, str,
None, None,
) )
......
...@@ -12,6 +12,7 @@ from sglang.utils import http_request ...@@ -12,6 +12,7 @@ from sglang.utils import http_request
class RuntimeEndpoint(BaseBackend): class RuntimeEndpoint(BaseBackend):
def __init__( def __init__(
self, self,
base_url: str, base_url: str,
...@@ -37,8 +38,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -37,8 +38,7 @@ class RuntimeEndpoint(BaseBackend):
self.model_info = res.json() self.model_info = res.json()
self.chat_template = get_chat_template_by_model_path( self.chat_template = get_chat_template_by_model_path(
self.model_info["model_path"] self.model_info["model_path"])
)
def get_model_name(self): def get_model_name(self):
return self.model_info["model_path"] return self.model_info["model_path"]
...@@ -124,6 +124,11 @@ class RuntimeEndpoint(BaseBackend): ...@@ -124,6 +124,11 @@ class RuntimeEndpoint(BaseBackend):
else: else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
value = getattr(sampling_params, item, None)
if value is not None:
data[item] = value
self._add_images(s, data) self._add_images(s, data)
res = http_request( res = http_request(
...@@ -166,6 +171,11 @@ class RuntimeEndpoint(BaseBackend): ...@@ -166,6 +171,11 @@ class RuntimeEndpoint(BaseBackend):
else: else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
value = getattr(sampling_params, item, None)
if value is not None:
data[item] = value
data["stream"] = True data["stream"] = True
self._add_images(s, data) self._add_images(s, data)
......
...@@ -668,6 +668,10 @@ class StreamExecutor: ...@@ -668,6 +668,10 @@ class StreamExecutor:
"frequency_penalty", "frequency_penalty",
"presence_penalty", "presence_penalty",
"ignore_eos", "ignore_eos",
"return_logprob",
"logprob_start_len",
"top_logprobs_num",
"return_text_in_logprobs",
"dtype", "dtype",
"regex", "regex",
]: ]:
......
...@@ -23,6 +23,10 @@ class SglSamplingParams: ...@@ -23,6 +23,10 @@ class SglSamplingParams:
frequency_penalty: float = 0.0 frequency_penalty: float = 0.0
presence_penalty: float = 0.0 presence_penalty: float = 0.0
ignore_eos: bool = False ignore_eos: bool = False
return_logprob: Optional[bool] = None
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
# for constrained generation, not included in to_xxx_kwargs # for constrained generation, not included in to_xxx_kwargs
dtype: Optional[str] = None dtype: Optional[str] = None
...@@ -37,6 +41,11 @@ class SglSamplingParams: ...@@ -37,6 +41,11 @@ class SglSamplingParams:
self.top_k, self.top_k,
self.frequency_penalty, self.frequency_penalty,
self.presence_penalty, self.presence_penalty,
self.ignore_eos,
self.return_logprob,
self.logprob_start_len,
self.top_logprobs_num,
self.return_text_in_logprobs,
) )
def to_openai_kwargs(self): def to_openai_kwargs(self):
...@@ -139,6 +148,10 @@ class SglFunction: ...@@ -139,6 +148,10 @@ class SglFunction:
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
ignore_eos: bool = False, ignore_eos: bool = False,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
stream: bool = False, stream: bool = False,
backend=None, backend=None,
**kwargs, **kwargs,
...@@ -154,6 +167,10 @@ class SglFunction: ...@@ -154,6 +167,10 @@ class SglFunction:
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
return_text_in_logprobs=return_text_in_logprobs,
) )
backend = backend or global_config.default_backend backend = backend or global_config.default_backend
return run_program(self, backend, args, kwargs, default_sampling_para, stream) return run_program(self, backend, args, kwargs, default_sampling_para, stream)
...@@ -170,6 +187,10 @@ class SglFunction: ...@@ -170,6 +187,10 @@ class SglFunction:
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
ignore_eos: bool = False, ignore_eos: bool = False,
return_logprob: Optional[bool] = None,
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
backend=None, backend=None,
num_threads: Union[str, int] = "auto", num_threads: Union[str, int] = "auto",
progress_bar: bool = False, progress_bar: bool = False,
...@@ -203,6 +224,10 @@ class SglFunction: ...@@ -203,6 +224,10 @@ class SglFunction:
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
return_text_in_logprobs=return_text_in_logprobs,
) )
backend = backend or global_config.default_backend backend = backend or global_config.default_backend
return run_program_batch( return run_program_batch(
...@@ -350,7 +375,7 @@ class SglArgument(SglExpr): ...@@ -350,7 +375,7 @@ class SglArgument(SglExpr):
class SglImage(SglExpr): class SglImage(SglExpr):
def __init__(self, path): def __init__(self, path: str):
self.path = path self.path = path
def __repr__(self) -> str: def __repr__(self) -> str:
...@@ -358,7 +383,7 @@ class SglImage(SglExpr): ...@@ -358,7 +383,7 @@ class SglImage(SglExpr):
class SglVideo(SglExpr): class SglVideo(SglExpr):
def __init__(self, path, num_frames): def __init__(self, path: str, num_frames: int):
self.path = path self.path = path
self.num_frames = num_frames self.num_frames = num_frames
...@@ -369,18 +394,23 @@ class SglVideo(SglExpr): ...@@ -369,18 +394,23 @@ class SglVideo(SglExpr):
class SglGen(SglExpr): class SglGen(SglExpr):
def __init__( def __init__(
self, self,
name, name: Optional[str] = None,
max_new_tokens, max_new_tokens: Optional[int] = None,
stop, stop: Optional[Union[str, List[str]]] = None,
temperature, temperature: Optional[float] = None,
top_p, top_p: Optional[float] = None,
top_k, top_k: Optional[int] = None,
frequency_penalty, frequency_penalty: Optional[float] = None,
presence_penalty, presence_penalty: Optional[float] = None,
ignore_eos, ignore_eos: Optional[bool] = None,
dtype, return_logprob: Optional[bool] = None,
regex, logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
dtype: Optional[type] = None,
regex: Optional[str] = None,
): ):
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
super().__init__() super().__init__()
self.name = name self.name = name
self.sampling_params = SglSamplingParams( self.sampling_params = SglSamplingParams(
...@@ -392,6 +422,10 @@ class SglGen(SglExpr): ...@@ -392,6 +422,10 @@ class SglGen(SglExpr):
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
ignore_eos=ignore_eos, ignore_eos=ignore_eos,
return_logprob=return_logprob,
logprob_start_len=logprob_start_len,
top_logprobs_num=top_logprobs_num,
return_text_in_logprobs=return_text_in_logprobs,
dtype=dtype, dtype=dtype,
regex=regex, regex=regex,
) )
...@@ -401,7 +435,7 @@ class SglGen(SglExpr): ...@@ -401,7 +435,7 @@ class SglGen(SglExpr):
class SglConstantText(SglExpr): class SglConstantText(SglExpr):
def __init__(self, value): def __init__(self, value: str):
super().__init__() super().__init__()
self.value = value self.value = value
...@@ -410,7 +444,7 @@ class SglConstantText(SglExpr): ...@@ -410,7 +444,7 @@ class SglConstantText(SglExpr):
class SglRoleBegin(SglExpr): class SglRoleBegin(SglExpr):
def __init__(self, role): def __init__(self, role: str):
super().__init__() super().__init__()
self.role = role self.role = role
...@@ -419,7 +453,7 @@ class SglRoleBegin(SglExpr): ...@@ -419,7 +453,7 @@ class SglRoleBegin(SglExpr):
class SglRoleEnd(SglExpr): class SglRoleEnd(SglExpr):
def __init__(self, role): def __init__(self, role: str):
super().__init__() super().__init__()
self.role = role self.role = role
...@@ -428,7 +462,7 @@ class SglRoleEnd(SglExpr): ...@@ -428,7 +462,7 @@ class SglRoleEnd(SglExpr):
class SglSelect(SglExpr): class SglSelect(SglExpr):
def __init__(self, name, choices, temperature): def __init__(self, name: str, choices: List[str], temperature: float):
super().__init__() super().__init__()
self.name = name self.name = name
self.choices = choices self.choices = choices
...@@ -439,7 +473,7 @@ class SglSelect(SglExpr): ...@@ -439,7 +473,7 @@ class SglSelect(SglExpr):
class SglFork(SglExpr): class SglFork(SglExpr):
def __init__(self, number, position_ids_offset=None): def __init__(self, number: int, position_ids_offset=None):
super().__init__() super().__init__()
self.number = number self.number = number
self.position_ids_offset = position_ids_offset self.position_ids_offset = position_ids_offset
...@@ -452,7 +486,7 @@ class SglFork(SglExpr): ...@@ -452,7 +486,7 @@ class SglFork(SglExpr):
class SglGetForkItem(SglExpr): class SglGetForkItem(SglExpr):
def __init__(self, index): def __init__(self, index: int):
super().__init__() super().__init__()
self.index = index self.index = index
...@@ -461,7 +495,7 @@ class SglGetForkItem(SglExpr): ...@@ -461,7 +495,7 @@ class SglGetForkItem(SglExpr):
class SglVariable(SglExpr): class SglVariable(SglExpr):
def __init__(self, name, source): def __init__(self, name: str, source):
super().__init__() super().__init__()
self.name = name self.name = name
self.source = source self.source = source
...@@ -471,7 +505,7 @@ class SglVariable(SglExpr): ...@@ -471,7 +505,7 @@ class SglVariable(SglExpr):
class SglVarScopeBegin(SglExpr): class SglVarScopeBegin(SglExpr):
def __init__(self, name): def __init__(self, name: str):
super().__init__() super().__init__()
self.name = name self.name = name
...@@ -480,7 +514,7 @@ class SglVarScopeBegin(SglExpr): ...@@ -480,7 +514,7 @@ class SglVarScopeBegin(SglExpr):
class SglVarScopeEnd(SglExpr): class SglVarScopeEnd(SglExpr):
def __init__(self, name): def __init__(self, name: str):
super().__init__() super().__init__()
self.name = name self.name = name
...@@ -502,4 +536,4 @@ class SglCommitLazy(SglExpr): ...@@ -502,4 +536,4 @@ class SglCommitLazy(SglExpr):
super().__init__() super().__init__()
def __repr__(self): def __repr__(self):
return f"CommitLazy()" return "CommitLazy()"
...@@ -333,17 +333,18 @@ class TokenizerManager: ...@@ -333,17 +333,18 @@ class TokenizerManager:
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens( ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
) )
if top_logprobs_num > 0:
ret["meta_info"][ if top_logprobs_num > 0:
"prefill_top_logprobs" ret["meta_info"][
] = self.detokenize_top_logprobs_tokens( "prefill_top_logprobs"
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs ] = self.detokenize_top_logprobs_tokens(
) ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
ret["meta_info"][ )
"decode_top_logprobs" ret["meta_info"][
] = self.detokenize_top_logprobs_tokens( "decode_top_logprobs"
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs ] = self.detokenize_top_logprobs_tokens(
) ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
return ret return ret
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text): def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
...@@ -383,7 +384,7 @@ def get_pixel_values( ...@@ -383,7 +384,7 @@ def get_pixel_values(
try: try:
processor = processor or global_processor processor = processor or global_processor
image, image_size = load_image(image_data) image, image_size = load_image(image_data)
if image_size != None: if image_size is not None:
image_hash = hash(image_data) image_hash = hash(image_data)
pixel_values = processor.image_processor(image)["pixel_values"] pixel_values = processor.image_processor(image)["pixel_values"]
for _ in range(len(pixel_values)): for _ in range(len(pixel_values)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment