Unverified commit ad4125d1 authored by Lianmin Zheng, committed by GitHub

Fuse more ops & Simplify token mapping (#1758)

parent 17536e7e
@@ -46,4 +46,8 @@ pip install nvtx
 import nvtx
 with nvtx.annotate("description", color="color"):
     # some critical code
-```
\ No newline at end of file
+```
+
+## Other tips
+
+1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder.
@@ -337,7 +337,7 @@ class FlashInferIndicesUpdaterDecode:
     def update(
         self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
     ):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

     def update_single_wrapper(
@@ -432,8 +432,8 @@ class FlashInferIndicesUpdaterDecode:
         kv_start_idx,
     ):
         bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indices = torch.empty(
             paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
         )
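Both the decode updater above and the prefill updater below apply the same reordering: the pre-allocated `kv_indptr` buffer is sliced to `bs + 1` entries first, and the cumulative sum of the per-request KV lengths is then written into positions `1:`, leaving entry 0 as the zero offset. A minimal standalone sketch of the resulting layout, assuming a zero-initialized buffer and made-up lengths:

```python
import torch

# Hypothetical per-request KV-cache lengths for a batch of 4 requests.
paged_kernel_lens = torch.tensor([3, 5, 2, 4], dtype=torch.int32)
bs = paged_kernel_lens.numel()

# Pre-allocated, zero-initialized indptr buffer (assumed; sized for the maximum batch).
kv_indptr_buf = torch.zeros(128 + 1, dtype=torch.int32)

# Order used by the new code: slice to bs + 1 entries, then fill positions 1..bs.
kv_indptr = kv_indptr_buf[: bs + 1]
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# Request i owns kv_indices[kv_indptr[i] : kv_indptr[i + 1]].
print(kv_indptr.tolist())  # [0, 3, 8, 10, 14]
```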
@@ -497,7 +497,7 @@ class FlashInferIndicesUpdaterPrefill:
             self.update = self.update_single_wrapper

     def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

     def update_single_wrapper(
@@ -589,8 +589,8 @@ class FlashInferIndicesUpdaterPrefill:
         use_ragged,
     ):
         bs = len(req_pool_indices)
-        kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indptr = kv_indptr[: bs + 1]
+        kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -602,8 +602,8 @@ class FlashInferIndicesUpdaterPrefill:
             self.max_context_len,
         )

-        qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
         qo_indptr = qo_indptr[: bs + 1]
+        qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

         # extend part
         if use_ragged:
......
@@ -33,56 +33,61 @@ class Sampler(nn.Module):
         if isinstance(logits, LogitsProcessorOutput):
             logits = logits.next_token_logits

-        # Post process logits
         logits = logits.contiguous()
-        logits.div_(sampling_info.temperatures)
-        probs = torch.softmax(logits, dim=-1)
-        logits = None
-        del logits

-        if self.use_nan_detectioin and torch.any(torch.isnan(probs)):
-            logger.warning("Detected errors during sampling! NaN in the probability.")
-            probs = torch.where(
-                torch.isnan(probs), torch.full_like(probs, 1e-10), probs
+        if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
             )

         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
-            batch_next_token_ids = torch.argmax(probs, -1)
-        elif global_server_args_dict["sampling_backend"] == "flashinfer":
-            max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
-            if sampling_info.need_min_p_sampling:
-                probs = top_k_renorm_prob(probs, sampling_info.top_ks)
-                probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
-                )
-            else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-                    probs,
-                    uniform_samples,
-                    sampling_info.top_ks,
-                    sampling_info.top_ps,
-                    filter_apply_order="joint",
-                )
-
-            if not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-        elif global_server_args_dict["sampling_backend"] == "pytorch":
-            # Here we provide a slower fallback implementation.
-            batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
-                probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-            )
-        else:
-            raise ValueError(
-                f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-            )
-
-        return batch_next_token_ids
+            batch_next_token_ids = torch.argmax(logits, -1)
+        else:
+            # Post process logits
+            logits.div_(sampling_info.temperatures)
+            probs = torch.softmax(logits, dim=-1)
+            logits = None
+            del logits
+
+            if global_server_args_dict["sampling_backend"] == "flashinfer":
+                max_top_k_round, batch_size = 32, probs.shape[0]
+                uniform_samples = torch.rand(
+                    (max_top_k_round, batch_size), device=probs.device
+                )
+                if sampling_info.need_min_p_sampling:
+                    probs = top_k_renorm_prob(probs, sampling_info.top_ks)
+                    probs = top_p_renorm_prob(probs, sampling_info.top_ps)
+                    batch_next_token_ids, success = min_p_sampling_from_probs(
+                        probs, uniform_samples, sampling_info.min_ps
+                    )
+                else:
+                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                        probs,
+                        uniform_samples,
+                        sampling_info.top_ks,
+                        sampling_info.top_ps,
+                        filter_apply_order="joint",
+                    )
+
+                if not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+            elif global_server_args_dict["sampling_backend"] == "pytorch":
+                # A slower fallback implementation with torch native operations.
+                batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
+                    probs,
+                    sampling_info.top_ks,
+                    sampling_info.top_ps,
+                    sampling_info.min_ps,
+                )
+            else:
+                raise ValueError(
+                    f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+                )
+
+        return batch_next_token_ids.to(torch.int32)


 def top_k_top_p_min_p_sampling_from_probs_torch(
......
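The rewritten `Sampler.forward` above short-circuits all-greedy batches with `torch.argmax` on the raw logits and only applies temperature scaling and softmax when sampling is actually needed; the `pytorch` backend then falls back to `top_k_top_p_min_p_sampling_from_probs_torch`. The sketch below is a simplified, self-contained torch-native variant of that joint top-k/top-p/min-p filtering, written for illustration rather than copied from the repository:

```python
import torch

def sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps):
    """Simplified torch-native joint top-k/top-p/min-p sampling over [batch, vocab] probabilities."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)

    # min-p: drop tokens whose probability falls below min_p * (max probability in the row).
    probs_sort[probs_sort < (probs_sort[:, 0:1] * min_ps.unsqueeze(-1))] = 0.0

    # top-p: drop a token once the cumulative mass *before* it already exceeds top_p.
    probs_sort[(probs_sum - probs_sort) > top_ps.unsqueeze(-1)] = 0.0

    # top-k: keep only the first top_k tokens of each sorted row.
    ranks = torch.arange(probs.shape[-1], device=probs.device).expand_as(probs_sort)
    probs_sort[ranks >= top_ks.unsqueeze(-1)] = 0.0

    # Renormalize, sample in sorted order, then map back to vocabulary ids.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    sampled_sorted = torch.multinomial(probs_sort, num_samples=1)
    return probs_idx.gather(-1, sampled_sorted).squeeze(-1)

# Tiny usage example: batch of 2, vocabulary of 8, temperature 0.7.
probs = torch.softmax(torch.randn(2, 8) / 0.7, dim=-1)
next_ids = sample_top_k_top_p_min_p(
    probs,
    top_ks=torch.tensor([5, 3]),
    top_ps=torch.tensor([0.9, 0.8]),
    min_ps=torch.tensor([0.05, 0.0]),
)
print(next_ids)
```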
@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs

 logger = logging.getLogger(__name__)


+@torch.compile(dynamic=True)
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
 class TpModelWorkerClient:
     """A tensor parallel model worker."""
@@ -99,33 +108,25 @@ class TpModelWorkerClient:
         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids
-        input_ids[:] = torch.where(
-            input_ids < 0,
-            self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
-            input_ids,
-        )
+        resolve_future_token_ids(input_ids, self.future_token_ids_map)

         # Run forward
         logits_output, next_token_ids = self.worker.forward_batch_generation(
             model_worker_batch
         )
+        self.launch_event.set()

         # Update the future token ids map
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = torch.arange(
-            -(future_token_ids_ct + bs),
-            -(future_token_ids_ct),
-            dtype=torch.int32,
-            device=self.device,
-        )
-        self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
-            torch.int32
-        )
+        self.future_token_ids_map[
+            future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+        ] = next_token_ids

+        # Copy results to the CPU
         next_token_ids = next_token_ids.to("cpu", non_blocking=True)
         copy_event = torch.cuda.Event(blocking=True)
         copy_event.record()
-        self.launch_event.set()
         self.copy_queue.put((copy_event, next_token_ids))

     def copy_thread_func(self):
@@ -149,8 +150,9 @@ class TpModelWorkerClient:
         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
         future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + bs),
-            -(self.future_token_ids_ct),
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
             dtype=torch.int32,
             device=self.device,
         )
......
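The `TpModelWorkerClient` changes above simplify the placeholder ("future token id") bookkeeping: the scheduler hands out negative ids counting down from `-(future_token_ids_ct + 1)`, the worker later writes the real next-token ids into slots `ct + 1 .. ct + bs` of `future_token_ids_map`, and the compiled `resolve_future_token_ids` swaps any remaining placeholders in the next batch's input for the real ids. A self-contained sketch of that round trip, with made-up sizes and token values:

```python
import torch

# Illustrative worker state; sizes and counts are hypothetical.
future_token_ids_map = torch.zeros(16, dtype=torch.int32)  # slot 0 is never read for real ids
future_token_ids_ct = 2   # two future slots already handed out earlier
bs = 3                    # current batch size

# Scheduler side: hand out placeholders -(ct+1), -(ct+2), ..., -(ct+bs)
# for tokens the GPU worker has not produced yet (mirrors the new torch.arange(..., -1)).
placeholders = torch.arange(
    -(future_token_ids_ct + 1), -(future_token_ids_ct + 1 + bs), -1, dtype=torch.int32
)
print(placeholders.tolist())  # [-3, -4, -5]

# Worker side, once real next-token ids are known: fill slots ct+1 .. ct+bs.
next_token_ids = torch.tensor([101, 7, 42], dtype=torch.int32)
future_token_ids_map[future_token_ids_ct + 1 : future_token_ids_ct + bs + 1] = next_token_ids

# The next step's input may mix real ids (>= 0) and placeholders (< 0);
# resolve the placeholders exactly as resolve_future_token_ids() does.
input_ids = torch.tensor([5, -3, -5, 9], dtype=torch.int32)
input_ids[:] = torch.where(
    input_ids < 0,
    future_token_ids_map[torch.clamp(-input_ids, min=0)],
    input_ids,
)
print(input_ids.tolist())  # [5, 101, 42, 9]
```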
@@ -51,7 +51,7 @@ class ReqToTokenPool:
             self.write = self.write_without_records

     def write(self, indices, values):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()

     def available_size(self):
@@ -221,16 +221,21 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         cache_v: torch.Tensor,
     ):
         layer_id = layer.layer_id
-        if cache_k.dtype != self.dtype:
-            cache_k = cache_k.to(self.dtype)
-        if cache_v.dtype != self.dtype:
-            cache_v = cache_v.to(self.dtype)
-        if self.store_dtype != self.dtype:
-            self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
-            self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
-        else:
-            self.k_buffer[layer_id][loc] = cache_k
-            self.v_buffer[layer_id][loc] = cache_v
+        copy_two_array(
+            loc,
+            self.k_buffer[layer_id],
+            cache_k,
+            self.v_buffer[layer_id],
+            cache_v,
+            self.dtype,
+            self.store_dtype,
+        )
+
+
+@torch.compile(dynamic=True)
+def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
+    dst_1[loc] = src_1.to(dtype).view(store_dtype)
+    dst_2[loc] = src_2.to(dtype).view(store_dtype)


 class MLATokenToKVPool(BaseTokenToKVPool):
......
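The KV-cache write path in `MHATokenToKVPool` above replaces the per-tensor dtype branching with a single `torch.compile`d helper, so the cast to the compute dtype, the reinterpreting `view` to the storage dtype, and both scatter writes can be fused into one compiled region. A standalone sketch with made-up shapes, using `store_dtype == dtype` so the `.view()` is a no-op reinterpretation:

```python
import torch

@torch.compile(dynamic=True)
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
    # Cast to the compute dtype, reinterpret as the storage dtype, and scatter into both buffers.
    dst_1[loc] = src_1.to(dtype).view(store_dtype)
    dst_2[loc] = src_2.to(dtype).view(store_dtype)

# Illustrative buffers: 32 token slots, 4 heads, head_dim 8 (hypothetical sizes).
dtype = store_dtype = torch.float16
k_buffer = torch.zeros(32, 4, 8, dtype=store_dtype)
v_buffer = torch.zeros(32, 4, 8, dtype=store_dtype)

loc = torch.tensor([3, 7, 9])                          # destination slots for 3 new tokens
cache_k = torch.randn(3, 4, 8, dtype=torch.float32)    # incoming K/V in a different dtype
cache_v = torch.randn(3, 4, 8, dtype=torch.float32)

copy_two_array(loc, k_buffer, cache_k, v_buffer, cache_v, dtype, store_dtype)
print(k_buffer[loc, 0, 0])  # the three written slots are now populated
```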
@@ -92,6 +92,11 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024


+@torch.compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
@@ -112,7 +117,6 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
         self.capture_bs = [
             bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
         ]
@@ -253,7 +257,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             top_logprobs_nums=[0] * bs,
-            positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
+            positions=clamp_position(seq_lens),
         )
         return forward(input_ids, forward_batch.positions, forward_batch)
......
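The CUDA-graph runner above moves the decode-position computation into the small compiled `clamp_position` helper so the subtract/clamp/cast sequence can be fused; the expression itself is unchanged from the old inline version. A quick usage sketch with a hypothetical batch:

```python
import torch

@torch.compile(dynamic=True)
def clamp_position(seq_lens):
    # Decode position of each request is its current length minus one,
    # clamped at 0 so an empty (padded) entry does not produce a negative position.
    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)

# Hypothetical batch of sequence lengths, including a padded zero-length entry.
seq_lens = torch.tensor([1, 17, 0, 256], dtype=torch.int32)
print(clamp_position(seq_lens).tolist())  # [0, 16, 0, 255]
```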
@@ -67,6 +67,7 @@ def run_eval(args):
         model=args.model,
         max_tokens=2048,
         base_url=base_url,
+        temperature=getattr(args, "temperature", 0.0),
     )

     # Run eval
@@ -119,6 +120,7 @@ if __name__ == "__main__":
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
     parser.add_argument("--num-threads", type=int, default=512)
+    parser.add_argument("--temperature", type=float, default=0.0)
     args = parser.parse_args()

     run_eval(args)
@@ -31,6 +31,7 @@ class TestEvalAccuracyMini(unittest.TestCase):
             eval_name="mmlu",
             num_examples=64,
             num_threads=32,
+            temperature=0.1,
         )

         metrics = run_eval(args)
......
@@ -23,7 +23,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--sampling-backend", "pytorch"],
+            other_args=["--sampling-backend", "pytorch", "--disable-radix-cache"],
         )

     @classmethod
@@ -37,6 +37,7 @@ class TestPyTorchSamplingBackend(unittest.TestCase):
             eval_name="mmlu",
             num_examples=64,
             num_threads=32,
+            temperature=0.1,
         )

         metrics = run_eval(args)
......