"vscode:/vscode.git/clone" did not exist on "8a0434d2a39f48c1dbc52a9dcfa6a9dfa43a574f"
Unverified commit 11616fc6, authored by sglang, committed by GitHub

Minor fix in compiler & format (#545)

parent 9ce89bc1
```diff
@@ -38,7 +38,6 @@ def sample_requests(
     num_requests: int,
     tokenizer: AutoTokenizer,
 ) -> List[Tuple[str, int, int]]:
     def load_dataset():
         with open(dataset_path, encoding="utf-8") as f:
             dataset = json.load(f)
```
...
```diff
@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
     )
     for i in redirect_indices:
         target_idx = np.random.choice(min(i * 2 + 100, num_lines))
-        lines[i] = (
-            f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
-        )
+        lines[
+            i
+        ] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
         redirects[i] = target_idx
     # Build links and find sources
```
...
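The hunk above is purely mechanical reformatting: the subscript target is split across lines instead of parenthesizing the right-hand side. A minimal self-contained sketch with made-up data, showing that the two spellings are the same assignment:

```python
# Both spellings perform the same subscript assignment; only the
# line-wrapping differs. The data here is invented for illustration.
lines = ["a", "b", "c"]
i = 1

# Old formatting: parenthesized right-hand side.
lines[i] = (
    "new value"
)

# New formatting: the subscript target is split across lines instead.
lines[
    i
] = "new value"

assert lines == ["a", "new value", "c"]
```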
```diff
@@ -13,7 +13,6 @@ except ImportError as e:
 class LiteLLM(BaseBackend):
     def __init__(
         self,
         model_name,
```
...
```diff
@@ -4,7 +4,7 @@ from queue import Queue
 from typing import List, Union
 from sglang.global_config import global_config
-from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
+from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
 from sglang.lang.ir import (
     SglArgument,
     SglConstantText,
@@ -184,7 +184,7 @@ class CompiledFunction:
         # Extract prefix by tracing and cache it
         if len(batch_kwargs) > 1:
-            pin_program(self.function, backend)
+            cache_program(self.function, backend)
         # Run all programs
         if num_threads == "auto":
```
...
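The only identifier change in this commit is the rename of `pin_program` to `cache_program`. A hedged sketch of the call path summarized from the hunk above; the wrapper name `warm_batch_prefix` is hypothetical, and `backend` is assumed to be an already-constructed sglang backend object:

```python
from sglang.lang.interpreter import cache_program

def warm_batch_prefix(function, backend, batch_kwargs):
    # Summarized from CompiledFunction.run_batch above: when more than
    # one kwargs dict is submitted, trace the function once and cache
    # its shared prompt prefix on the backend before running the batch.
    if len(batch_kwargs) > 1:
        cache_program(function, backend)  # formerly pin_program
```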
```diff
@@ -6,7 +6,6 @@ import multiprocessing as mp
 from sglang.srt.server import ServerArgs, launch_server
 if __name__ == "__main__":
     model_overide_args = {}
     model_overide_args["mm_spatial_pool_stride"] = 2
```
...
```diff
@@ -498,9 +498,10 @@ class Batch:
                     req.output_ids = cur_output_ids
                     continue
-                jump_forward_str, next_state = (
-                    req.jump_forward_map.jump_forward_symbol(cur_state)
-                )
+                (
+                    jump_forward_str,
+                    next_state,
+                ) = req.jump_forward_map.jump_forward_symbol(cur_state)
                 # Make the incrementally decoded text part of jump_forward_str
                 # so that the UTF-8 will not corrupt
```
...
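As with the earlier hunk, behavior is unchanged: Python accepts the parentheses on either side of a tuple unpacking. A tiny sketch with a stand-in for `jump_forward_symbol` (its real return type is an assumption here):

```python
def jump_forward_symbol(state):
    # Stand-in for req.jump_forward_map.jump_forward_symbol; assumed to
    # return the text to fast-forward by and the resulting FSM state.
    return "example", state + 1

# Old formatting: parenthesized call on the right-hand side.
jump_forward_str, next_state = (
    jump_forward_symbol(0)
)

# New formatting: parenthesized target list on the left-hand side.
(
    jump_forward_str,
    next_state,
) = jump_forward_symbol(0)
```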
```diff
@@ -283,13 +283,14 @@ class ModelTpServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-            req.origin_input_ids, req.image_offset = (
-                self.model_runner.model.pad_input_ids(
-                    req.origin_input_ids_unpadded,
-                    req.pad_value,
-                    req.pixel_values.shape,
-                    req.image_size,
-                )
-            )
+            (
+                req.origin_input_ids,
+                req.image_offset,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values.shape,
+                req.image_size,
+            )
             req.sampling_params = recv_req.sampling_params
             req.return_logprob = recv_req.return_logprob
```
...
```diff
@@ -35,7 +35,6 @@ class GenerateReqInput:
     stream: bool = False
     def post_init(self):
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
```
...
```diff
@@ -334,15 +334,15 @@ class TokenizerManager:
                     ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
                 )
             if top_logprobs_num > 0:
-                ret["meta_info"]["prefill_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
-                    )
-                )
-                ret["meta_info"]["decode_top_logprobs"] = (
-                    self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
-                    )
-                )
+                ret["meta_info"][
+                    "prefill_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                )
+                ret["meta_info"][
+                    "decode_top_logprobs"
+                ] = self.detokenize_top_logprobs_tokens(
+                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                )
         return ret
```
...
```diff
@@ -36,7 +36,6 @@ LoraConfig = None
 class GLMAttention(nn.Module):
     def __init__(
         self,
         config,
@@ -294,7 +293,6 @@ class GLMTransformer(nn.Module):
 class ChatGLMModel(nn.Module):
     def __init__(
         self,
         config,
```
...
```diff
@@ -521,7 +521,6 @@ class Grok1DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = (
             self.post_attn_norm(
                 self.self_attn(
```
...
```diff
@@ -160,9 +160,9 @@ class LlamaDecoderLayer(nn.Module):
         if rope_scaling is not None and getattr(
             config, "original_max_position_embeddings", None
         ):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings
-            )
+            rope_scaling[
+                "original_max_position_embeddings"
+            ] = config.original_max_position_embeddings
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
```
...
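For context on what this (unchanged) logic does: when a model config carries both a `rope_scaling` dict and an `original_max_position_embeddings` field, the field is copied into the dict so the rotary-embedding code can see the pre-scaling context length. A runnable sketch with a made-up config namespace standing in for a transformers `PretrainedConfig`:

```python
from types import SimpleNamespace

# Made-up stand-in for a transformers PretrainedConfig.
config = SimpleNamespace(
    rope_scaling={"type": "yarn", "factor": 4.0},
    original_max_position_embeddings=4096,
)

rope_scaling = config.rope_scaling
if rope_scaling is not None and getattr(
    config, "original_max_position_embeddings", None
):
    # Propagate the pre-scaling context length into the scaling dict.
    rope_scaling["original_max_position_embeddings"] = (
        config.original_max_position_embeddings
    )

assert rope_scaling["original_max_position_embeddings"] == 4096
```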