Unverified Commit 11616fc6 authored by sglang, committed by GitHub

Minor fix in compiler & format (#545)

parent 9ce89bc1
......@@ -38,7 +38,6 @@ def sample_requests(
num_requests: int,
tokenizer: AutoTokenizer,
) -> List[Tuple[str, int, int]]:
def load_dataset():
with open(dataset_path, encoding="utf-8") as f:
dataset = json.load(f)
......
......@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
)
for i in redirect_indices:
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
-                lines[i] = (
-                    f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
-                )
+                lines[
+                    i
+                ] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
redirects[i] = target_idx
# Build links and find sources
......
......@@ -13,7 +13,6 @@ except ImportError as e:
class LiteLLM(BaseBackend):
def __init__(
self,
model_name,
......
......@@ -4,7 +4,7 @@ from queue import Queue
from typing import List, Union
from sglang.global_config import global_config
-from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
+from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
from sglang.lang.ir import (
SglArgument,
SglConstantText,
......@@ -184,7 +184,7 @@ class CompiledFunction:
# Extract prefix by tracing and cache it
if len(batch_kwargs) > 1:
-            pin_program(self.function, backend)
+            cache_program(self.function, backend)
# Run all programs
if num_threads == "auto":
......
......@@ -6,7 +6,6 @@ import multiprocessing as mp
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = 2
......
......@@ -498,9 +498,10 @@ class Batch:
req.output_ids = cur_output_ids
continue
-                jump_forward_str, next_state = (
-                    req.jump_forward_map.jump_forward_symbol(cur_state)
-                )
+                (
+                    jump_forward_str,
+                    next_state,
+                ) = req.jump_forward_map.jump_forward_symbol(cur_state)
# Make the incrementally decoded text part of jump_forward_str
# so that the UTF-8 will not corrupt
......
......@@ -283,13 +283,14 @@ class ModelTpServer:
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
-                req.origin_input_ids, req.image_offset = (
-                    self.model_runner.model.pad_input_ids(
-                        req.origin_input_ids_unpadded,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
-                )
+                (
+                    req.origin_input_ids,
+                    req.image_offset,
+                ) = self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
+                )
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob
......
......@@ -35,7 +35,6 @@ class GenerateReqInput:
stream: bool = False
def post_init(self):
if (self.text is None and self.input_ids is None) or (
self.text is not None and self.input_ids is not None
):
......
......@@ -334,15 +334,15 @@ class TokenizerManager:
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
ret["meta_info"][
"prefill_top_logprobs"
] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
ret["meta_info"][
"decode_top_logprobs"
] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
return ret
......
......@@ -36,7 +36,6 @@ LoraConfig = None
class GLMAttention(nn.Module):
def __init__(
self,
config,
......@@ -294,7 +293,6 @@ class GLMTransformer(nn.Module):
class ChatGLMModel(nn.Module):
def __init__(
self,
config,
......
......@@ -521,7 +521,6 @@ class Grok1DecoderLayer(nn.Module):
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = (
self.post_attn_norm(
self.self_attn(
......
......@@ -160,9 +160,9 @@ class LlamaDecoderLayer(nn.Module):
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
rope_scaling[
"original_max_position_embeddings"
] = config.original_max_position_embeddings
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
......