Unverified Commit faba293a authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Improve gemma and documentations (#278)

parent 89885b31
...@@ -369,8 +369,13 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port ...@@ -369,8 +369,13 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
- Mistral - Mistral
- Mixtral - Mixtral
- Qwen / Qwen 2 - Qwen / Qwen 2
- Gemma
- Please add a new flag `--attention-reduce-in-fp32` to avoid some precision errors.
- `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32`
- LLaVA - LLaVA
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.6-34b --tokenizer-path liuhaotian/llava-v1.6-34b-tokenizer --port 3000`
- Yi-VL - Yi-VL
- see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py). - see [srt_example_yi_vl.py](examples/quick_start/srt_example_yi_vl.py).
- AWQ/GPTQ quantization - AWQ/GPTQ quantization
......
...@@ -21,7 +21,9 @@ class RuntimeEndpoint(BaseBackend): ...@@ -21,7 +21,9 @@ class RuntimeEndpoint(BaseBackend):
self.verify = verify self.verify = verify
res = http_request( res = http_request(
self.base_url + "/get_model_info", auth_token=self.auth_token, verify=self.verify self.base_url + "/get_model_info",
auth_token=self.auth_token,
verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
self.model_info = res.json() self.model_info = res.json()
...@@ -41,7 +43,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -41,7 +43,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}}, json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token, auth_token=self.auth_token,
verify=self.verify verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -50,7 +52,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -50,7 +52,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/generate", self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}}, json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token, auth_token=self.auth_token,
verify=self.verify verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -58,7 +60,10 @@ class RuntimeEndpoint(BaseBackend): ...@@ -58,7 +60,10 @@ class RuntimeEndpoint(BaseBackend):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data) self._add_images(s, data)
res = http_request( res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
...@@ -90,7 +95,10 @@ class RuntimeEndpoint(BaseBackend): ...@@ -90,7 +95,10 @@ class RuntimeEndpoint(BaseBackend):
self._add_images(s, data) self._add_images(s, data)
res = http_request( res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
verify=self.verify,
) )
obj = res.json() obj = res.json()
comp = obj["text"] comp = obj["text"]
...@@ -129,7 +137,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -129,7 +137,7 @@ class RuntimeEndpoint(BaseBackend):
json=data, json=data,
stream=True, stream=True,
auth_token=self.auth_token, auth_token=self.auth_token,
verify=self.verify verify=self.verify,
) )
pos = 0 pos = 0
...@@ -161,7 +169,10 @@ class RuntimeEndpoint(BaseBackend): ...@@ -161,7 +169,10 @@ class RuntimeEndpoint(BaseBackend):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data) self._add_images(s, data)
res = http_request( res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
prompt_len = res.json()["meta_info"]["prompt_tokens"] prompt_len = res.json()["meta_info"]["prompt_tokens"]
...@@ -175,7 +186,10 @@ class RuntimeEndpoint(BaseBackend): ...@@ -175,7 +186,10 @@ class RuntimeEndpoint(BaseBackend):
} }
self._add_images(s, data) self._add_images(s, data)
res = http_request( res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token, verify=self.verify self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
obj = res.json() obj = res.json()
...@@ -192,7 +206,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -192,7 +206,7 @@ class RuntimeEndpoint(BaseBackend):
self.base_url + "/concate_and_append_request", self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid}, json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token, auth_token=self.auth_token,
verify=self.verify verify=self.verify,
) )
assert res.status_code == 200 assert res.status_code == 200
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.utils import wrap_kernel_launcher
from sglang.srt.managers.router.model_runner import global_server_args from sglang.srt.managers.router.model_runner import global_server_args
from sglang.srt.utils import wrap_kernel_launcher
if global_server_args.attention_reduce_in_fp32: if global_server_args.attention_reduce_in_fp32:
REDUCE_TRITON_TYPE = tl.float32 REDUCE_TRITON_TYPE = tl.float32
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from torch import nn from torch import nn
from transformers import GemmaConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
...@@ -136,7 +136,7 @@ class GemmaAttention(nn.Module): ...@@ -136,7 +136,7 @@ class GemmaAttention(nn.Module):
class GemmaDecoderLayer(nn.Module): class GemmaDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: GemmaConfig, config: PretrainedConfig,
layer_id: int = 0, layer_id: int = 0,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
...@@ -190,7 +190,7 @@ class GemmaDecoderLayer(nn.Module): ...@@ -190,7 +190,7 @@ class GemmaDecoderLayer(nn.Module):
class GemmaModel(nn.Module): class GemmaModel(nn.Module):
def __init__( def __init__(
self, self,
config: GemmaConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -213,12 +213,12 @@ class GemmaModel(nn.Module): ...@@ -213,12 +213,12 @@ class GemmaModel(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
skip_embed: bool = False, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if not skip_embed: if input_embeds is None:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
else: else:
hidden_states = input_ids hidden_states = input_embeds
# Normalize the embedding by sqrt(hidden_size) # Normalize the embedding by sqrt(hidden_size)
hidden_states *= self.config.hidden_size**0.5 hidden_states *= self.config.hidden_size**0.5
...@@ -262,7 +262,7 @@ class GemmaForCausalLM(nn.Module): ...@@ -262,7 +262,7 @@ class GemmaForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: GemmaConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
...@@ -279,9 +279,9 @@ class GemmaForCausalLM(nn.Module): ...@@ -279,9 +279,9 @@ class GemmaForCausalLM(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
skip_embed: bool = False, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
) )
......
...@@ -233,9 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module): ...@@ -233,9 +233,7 @@ class LlavaLlamaForCausalLM(nn.Module):
input_ids, positions, input_metadata, input_embeds=input_embeds input_ids, positions, input_metadata, input_embeds=input_embeds
) )
elif input_metadata.forward_mode == ForwardMode.DECODE: elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model( return self.language_model(input_ids, positions, input_metadata)
input_ids, positions, input_metadata
)
def load_weights( def load_weights(
self, self,
......
...@@ -550,6 +550,7 @@ class Runtime: ...@@ -550,6 +550,7 @@ class Runtime:
tp_size: int = 1, tp_size: int = 1,
model_mode: List[str] = (), model_mode: List[str] = (),
schedule_heuristic: str = "lpm", schedule_heuristic: str = "lpm",
attention_reduce_in_fp32: bool = False,
random_seed: int = 42, random_seed: int = 42,
log_level: str = "error", log_level: str = "error",
port: Optional[int] = None, port: Optional[int] = None,
...@@ -572,6 +573,7 @@ class Runtime: ...@@ -572,6 +573,7 @@ class Runtime:
tp_size=tp_size, tp_size=tp_size,
model_mode=model_mode, model_mode=model_mode,
schedule_heuristic=schedule_heuristic, schedule_heuristic=schedule_heuristic,
attention_reduce_in_fp32=attention_reduce_in_fp32,
random_seed=random_seed, random_seed=random_seed,
log_level=log_level, log_level=log_level,
) )
......
...@@ -21,6 +21,7 @@ class ServerArgs: ...@@ -21,6 +21,7 @@ class ServerArgs:
model_mode: List[str] = () model_mode: List[str] = ()
schedule_heuristic: str = "lpm" schedule_heuristic: str = "lpm"
schedule_conservativeness: float = 1.0 schedule_conservativeness: float = 1.0
attention_reduce_in_fp32: bool = False
random_seed: int = 42 random_seed: int = 42
stream_interval: int = 8 stream_interval: int = 8
disable_log_stats: bool = False disable_log_stats: bool = False
...@@ -28,7 +29,6 @@ class ServerArgs: ...@@ -28,7 +29,6 @@ class ServerArgs:
log_level: str = "info" log_level: str = "info"
disable_regex_jump_forward: bool = False disable_regex_jump_forward: bool = False
disable_disk_cache: bool = False disable_disk_cache: bool = False
attention_reduce_in_fp32: bool = False
def __post_init__(self): def __post_init__(self):
if self.tokenizer_path is None: if self.tokenizer_path is None:
...@@ -157,6 +157,11 @@ class ServerArgs: ...@@ -157,6 +157,11 @@ class ServerArgs:
default=ServerArgs.random_seed, default=ServerArgs.random_seed,
help="Random seed.", help="Random seed.",
) )
parser.add_argument(
"--attention-reduce-in-fp32",
action="store_true",
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
)
parser.add_argument( parser.add_argument(
"--stream-interval", "--stream-interval",
type=int, type=int,
...@@ -190,11 +195,6 @@ class ServerArgs: ...@@ -190,11 +195,6 @@ class ServerArgs:
action="store_true", action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.", help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
) )
parser.add_argument(
"--attention-reduce-in-fp32",
action="store_true",
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
......
...@@ -97,7 +97,9 @@ def http_request(url, json=None, stream=False, auth_token=None, verify=None): ...@@ -97,7 +97,9 @@ def http_request(url, json=None, stream=False, auth_token=None, verify=None):
"Content-Type": "application/json", "Content-Type": "application/json",
"Authentication": f"Bearer {auth_token}", "Authentication": f"Bearer {auth_token}",
} }
return requests.post(url, json=json, stream=True, headers=headers, verify=verify) return requests.post(
url, json=json, stream=True, headers=headers, verify=verify
)
else: else:
req = urllib.request.Request(url) req = urllib.request.Request(url)
req.add_header("Content-Type", "application/json; charset=utf-8") req.add_header("Content-Type", "application/json; charset=utf-8")
......
...@@ -66,9 +66,9 @@ class BenchBatch: ...@@ -66,9 +66,9 @@ class BenchBatch:
p_idx = prefix_req_idx[i // fork_num].item() p_idx = prefix_req_idx[i // fork_num].item()
n_idx = self.req_pool_indices[i].item() n_idx = self.req_pool_indices[i].item()
req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len] req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
req_to_token[n_idx, prefix_len : prefix_len + extend_len] = ( req_to_token[
self.out_cache_loc[i * extend_len : (i + 1) * extend_len] n_idx, prefix_len : prefix_len + extend_len
) ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
def update_decode(self, predict_ids, batch_size): def update_decode(self, predict_ids, batch_size):
assert predict_ids.shape[0] == batch_size assert predict_ids.shape[0] == batch_size
...@@ -81,9 +81,9 @@ class BenchBatch: ...@@ -81,9 +81,9 @@ class BenchBatch:
self.out_cache_cont_start, self.out_cache_cont_start,
self.out_cache_cont_end, self.out_cache_cont_end,
) = self.token_to_kv_pool.alloc_contiguous(batch_size) ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = ( self.req_to_token_pool.req_to_token[
self.out_cache_loc self.req_pool_indices, self.seq_lens
) ] = self.out_cache_loc
self.seq_lens.add_(1) self.seq_lens.add_(1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment