Unverified Commit ad4125d1 authored by Lianmin Zheng, committed by GitHub

Fuse more ops & Simplify token mapping (#1758)

parent 17536e7e
......@@ -46,4 +46,8 @@ pip install nvtx
import nvtx
with nvtx.annotate("description", color="color"):
# some critical code
```
\ No newline at end of file
```
## Other tips
1. You can benchmark a model using dummy weights by providing only the `config.json` file. This allows quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands; you then only need a correct `config.json` under the checkpoint folder.
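Editor's note: for reference, a minimal self-contained sketch of the NVTX pattern shown above (the range name, color, and workload are placeholder values, not taken from the repository):

```python
import nvtx


def critical_section(n):
    # The annotated region shows up as a named range in the Nsight Systems timeline.
    with nvtx.annotate("critical_section", color="blue"):
        return sum(i * i for i in range(n))


if __name__ == "__main__":
    print(critical_section(1000))
```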
......@@ -337,7 +337,7 @@ class FlashInferIndicesUpdaterDecode:
def update(
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
):
# Keep the signature for type checking, will be initialized during runtime
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
def update_single_wrapper(
......@@ -432,8 +432,8 @@ class FlashInferIndicesUpdaterDecode:
kv_start_idx,
):
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
......@@ -497,7 +497,7 @@ class FlashInferIndicesUpdaterPrefill:
self.update = self.update_single_wrapper
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
# Keep the signature for type checking, will be initialized during runtime
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
def update_single_wrapper(
......@@ -589,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
use_ragged,
):
bs = len(req_pool_indices)
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
......@@ -602,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
self.max_context_len,
)
qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
# extend part
if use_ragged:
......
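Editor's note: the `kv_indptr` lines above build CSR-style offsets from per-request KV lengths via a cumulative sum. A toy sketch of that idea (the lengths are made up; the real code operates on GPU tensors supplied by the caller):

```python
import torch

# Per-request KV lengths for a batch of 3 requests (made-up values).
paged_kernel_lens = torch.tensor([3, 1, 4], dtype=torch.int32)
bs = paged_kernel_lens.numel()

# kv_indptr[i] marks where request i's entries start in the flat kv_indices array.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
print(kv_indptr)  # tensor([0, 3, 4, 8], dtype=torch.int32)
# Request i owns kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
```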
......@@ -33,56 +33,61 @@ class Sampler(nn.Module):
if isinstance(logits, LogitsProcessorOutput):
logits = logits.next_token_logits
# Post process logits
logits = logits.contiguous()
logits.div_(sampling_info.temperatures)
probs = torch.softmax(logits, dim=-1)
logits = None
del logits
if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
logger.warning("Detected errors during sampling! NaN in the probability.")
probs = torch.where(
torch.isnan(probs), torch.full_like(probs, 1e-10), probs
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits
)
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
batch_next_token_ids = torch.argmax(probs, -1)
elif global_server_args_dict["sampling_backend"] == "flashinfer":
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids, success = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
batch_next_token_ids = torch.argmax(logits, -1)
else:
# Post process logits
logits.div_(sampling_info.temperatures)
probs = torch.softmax(logits, dim=-1)
logits = None
del logits
if global_server_args_dict["sampling_backend"] == "flashinfer":
max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
(max_top_k_round, batch_size), device=probs.device
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
batch_next_token_ids, success = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
)
else:
batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
)
if not torch.all(success):
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
elif global_server_args_dict["sampling_backend"] == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs,
uniform_samples,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
sampling_info.min_ps,
)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
if not torch.all(success):
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
elif global_server_args_dict["sampling_backend"] == "pytorch":
# Here we provide a slower fallback implementation.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
)
return batch_next_token_ids
return batch_next_token_ids.to(torch.int32)
def top_k_top_p_min_p_sampling_from_probs_torch(
......
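Editor's note: one way to see why the all-greedy branch above can skip the temperature division and softmax entirely is that softmax with a strictly positive temperature is monotonic per row, so the argmax is unchanged. A small sanity-check sketch (shapes and values are arbitrary):

```python
import torch

logits = torch.randn(4, 32000)
temperatures = torch.rand(4, 1) + 0.5  # strictly positive

# argmax(softmax(logits / T)) == argmax(logits) for T > 0.
greedy_from_probs = torch.softmax(logits / temperatures, dim=-1).argmax(dim=-1)
greedy_from_logits = torch.argmax(logits, dim=-1)
assert torch.equal(greedy_from_probs, greedy_from_logits)
```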
......@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
@torch.compile(dynamic=True)
def resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
class TpModelWorkerClient:
"""A tensor parallel model worker."""
......@@ -99,33 +108,25 @@ class TpModelWorkerClient:
# Resolve future tokens in the input
input_ids = model_worker_batch.input_ids
input_ids[:] = torch.where(
input_ids < 0,
self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
resolve_future_token_ids(input_ids, self.future_token_ids_map)
# Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch
)
self.launch_event.set()
# Update the future token ids map
bs = len(model_worker_batch.seq_lens)
future_next_token_ids = torch.arange(
-(future_token_ids_ct + bs),
-(future_token_ids_ct),
dtype=torch.int32,
device=self.device,
)
self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
torch.int32
)
self.future_token_ids_map[
future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
] = next_token_ids
# Copy results to the CPU
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
copy_event = torch.cuda.Event(blocking=True)
copy_event.record()
self.launch_event.set()
self.copy_queue.put((copy_event, next_token_ids))
def copy_thread_func(self):
......@@ -149,8 +150,9 @@ class TpModelWorkerClient:
# Allocate output future objects
bs = len(model_worker_batch.seq_lens)
future_next_token_ids = torch.arange(
-(self.future_token_ids_ct + bs),
-(self.future_token_ids_ct),
-(self.future_token_ids_ct + 1),
-(self.future_token_ids_ct + 1 + bs),
-1,
dtype=torch.int32,
device=self.device,
)
......
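Editor's note: to make the simplified token mapping above concrete, here is a toy walk-through of the placeholder scheme. The counter value, batch size, and token IDs are invented, and default int64 tensors are used here for simplicity; the real buffers live on the GPU and use int32.

```python
import torch

ct, bs = 4, 3  # stand-ins for future_token_ids_ct and the batch size
future_token_ids_map = torch.zeros(16, dtype=torch.int64)

# Placeholder IDs handed to the next batch before sampling has finished.
placeholders = torch.arange(-(ct + 1), -(ct + 1 + bs), -1)  # tensor([-5, -6, -7])

# Once sampling finishes, the worker fills the matching slots of the map.
future_token_ids_map[ct + 1 : ct + 1 + bs] = torch.tensor([101, 102, 103])

# Resolving a batch that mixes real token IDs (>= 0) with placeholders (< 0).
input_ids = torch.tensor([7, -5, -7])
resolved = torch.where(
    input_ids < 0,
    future_token_ids_map[torch.clamp(-input_ids, min=0)],
    input_ids,
)
print(resolved)  # tensor([  7, 101, 103])
```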
......@@ -51,7 +51,7 @@ class ReqToTokenPool:
self.write = self.write_without_records
def write(self, indices, values):
# Keep the signature for type checking, will be initialized during runtime
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
def available_size(self):
......@@ -221,16 +221,21 @@ class MHATokenToKVPool(BaseTokenToKVPool):
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if cache_v.dtype != self.dtype:
cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
else:
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
copy_two_array(
loc,
self.k_buffer[layer_id],
cache_k,
self.v_buffer[layer_id],
cache_v,
self.dtype,
self.store_dtype,
)
@torch.compile(dynamic=True)
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_1[loc] = src_1.to(dtype).view(store_dtype)
dst_2[loc] = src_2.to(dtype).view(store_dtype)
class MLATokenToKVPool(BaseTokenToKVPool):
......
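Editor's note: the fused `copy_two_array` above relies on a cast-then-reinterpret pattern used when the KV cache is stored in a different dtype than the compute dtype. A rough sketch for a single buffer, assuming an fp8 cache stored through a uint8 view and a CUDA device (shapes, dtypes, and indices are placeholders, not the library's actual configuration):

```python
import torch

dtype, store_dtype = torch.float8_e4m3fn, torch.uint8
k_buffer = torch.zeros(8, 4, dtype=store_dtype, device="cuda")
cache_k = torch.randn(2, 4, dtype=torch.float16, device="cuda")
loc = torch.tensor([1, 5], device="cuda")

# Cast to the cache dtype, reinterpret the bytes as the storage dtype,
# then scatter into the preallocated buffer in place.
k_buffer[loc] = cache_k.to(dtype).view(store_dtype)
```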
......@@ -92,6 +92,11 @@ def set_torch_compile_config():
torch._dynamo.config.accumulated_cache_size_limit = 1024
@torch.compile(dynamic=True)
def clamp_position(seq_lens):
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
class CudaGraphRunner:
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
......@@ -112,7 +117,6 @@ class CudaGraphRunner:
self.capture_bs = list(range(1, 32)) + [64, 128]
else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.capture_bs = [
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
]
......@@ -253,7 +257,7 @@ class CudaGraphRunner:
encoder_lens=encoder_lens,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
positions=clamp_position(seq_lens),
)
return forward(input_ids, forward_batch.positions, forward_batch)
......
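Editor's note: for context on the `clamp_position` helper captured into the CUDA graph, the position of the token being decoded is `seq_len - 1`, and the clamp presumably guards padded or empty slots whose length is zero. A tiny check (the lengths are made up):

```python
import torch

seq_lens = torch.tensor([5, 1, 0], dtype=torch.int32)
positions = torch.clamp(seq_lens - 1, min=0).to(torch.int64)
print(positions)  # tensor([4, 0, 0])
```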
......@@ -67,6 +67,7 @@ def run_eval(args):
model=args.model,
max_tokens=2048,
base_url=base_url,
temperature=getattr(args, "temperature", 0.0),
)
# Run eval
......@@ -119,6 +120,7 @@ if __name__ == "__main__":
parser.add_argument("--eval-name", type=str, default="mmlu")
parser.add_argument("--num-examples", type=int)
parser.add_argument("--num-threads", type=int, default=512)
parser.add_argument("--temperature", type=float, default=0.0)
args = parser.parse_args()
run_eval(args)
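Editor's note: the `getattr(args, "temperature", 0.0)` call keeps callers working even if they build an args object without the new field, falling back to greedy decoding. A minimal illustration (the namespace below is hypothetical):

```python
from types import SimpleNamespace

old_style_args = SimpleNamespace(model="m", base_url="http://localhost:30000")
temperature = getattr(old_style_args, "temperature", 0.0)
print(temperature)  # 0.0
```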
......@@ -31,6 +31,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
......
......@@ -23,7 +23,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--sampling-backend", "pytorch"],
other_args=["--sampling-backend", "pytorch", "--disable-radix-cache"],
)
@classmethod
......@@ -37,6 +37,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
eval_name="mmlu",
num_examples=64,
num_threads=32,
temperature=0.1,
)
metrics = run_eval(args)
......