Unverified Commit 03886917 authored by Lianmin Zheng, committed by GitHub

Disable all two stream overlap on amd (#6475)

parent 66324895
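The "two stream overlap" disabled here is the CUDA trick of pushing an independent piece of work onto a second stream so it runs concurrently with the default stream, then synchronizing before the result is consumed. The commit gates every place that created such an alternate stream on is_cuda(), so AMD/ROCm (and any other non-CUDA backend) now falls back to the plain sequential path. A minimal sketch of the pattern under those assumptions; the helper name maybe_overlap is illustrative, not sglang code:

import torch

def maybe_overlap(fn_main, fn_side, alt_stream=None):
    # Sequential fallback: this is the path non-CUDA platforms take after
    # this commit, because their alt_stream is now always None.
    if alt_stream is None:
        return fn_main(), fn_side()
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)        # alt stream sees all prior work
    out_main = fn_main()                   # stays on the current stream
    with torch.cuda.stream(alt_stream):
        out_side = fn_side()               # runs concurrently on alt_stream
    current.wait_stream(alt_stream)        # join before anyone reads out_side
    return out_main, out_side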
@@ -38,11 +38,17 @@
 import triton
 import triton.language as tl
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, get_compiler_backend
+from sglang.srt.utils import (
+    debug_timing,
+    get_compiler_backend,
+    is_cuda,
+    next_power_of_2,
+)

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
+_is_cuda = is_cuda()

 class ReqToTokenPool:
@@ -262,7 +268,7 @@ class MHATokenToKVPool(KVCache):
         self.layer_transfer_counter = None
         self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream()
+        self.alt_stream = self.device_module.Stream() if _is_cuda else None

         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -392,7 +398,7 @@
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)

-        if self.capture_mode and cache_k.shape[0] < 4:
+        if self.capture_mode and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
......
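Both hunks in this file reach the stream API through self.device_module rather than torch.cuda directly: torch.get_device_module(self.device) resolves the device string to the matching torch backend module (torch.cuda for "cuda" in recent PyTorch), so .Stream(), .current_stream() and .stream() behave like their torch.cuda counterparts. A small hedged sketch of that indirection, guarded so it also runs without a GPU:

import torch

# Hedged sketch of the device_module indirection used above; assumes a recent
# PyTorch where torch.get_device_module("cuda") returns the torch.cuda module.
device_module = torch.get_device_module("cuda")
alt_stream = device_module.Stream() if torch.cuda.is_available() else None

if alt_stream is not None:
    current = device_module.current_stream()
    alt_stream.wait_stream(current)           # order alt_stream after current
    with device_module.stream(alt_stream):
        pass                                  # overlapped work would go here
    current.wait_stream(alt_stream)           # join back before reading results

Keeping the backend module in one attribute lets the new "if _is_cuda else None" gate be the only platform branch; every later use just checks whether alt_stream is None.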
@@ -76,13 +76,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import (
-    ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,
 )
 from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.operations import execute_operations
 from sglang.srt.operations_strategy import compute_layer_operations
@@ -1321,8 +1320,7 @@
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
-        # TODO(haishaw): multi-stream performance on ROCm
-        self.alt_stream = None if _is_hip else torch.cuda.Stream()
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
......
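The replaced line broadens the existing ROCm exception into a CUDA-only allowlist: the old expression only returned None on HIP builds, so any other non-CUDA backend would still attempt to create a CUDA stream, while the new one creates the stream only on CUDA. A small hedged illustration of the two expressions side by side (plain functions that assume a CUDA-capable setup, not DeepseekV2Model code):

import torch

def old_alt_stream(is_hip: bool):
    # Pre-change: only ROCm is excluded; a CPU or XPU build would still
    # attempt to create a CUDA stream here.
    return None if is_hip else torch.cuda.Stream()

def new_alt_stream(is_cuda: bool):
    # Post-change: only a CUDA build gets the extra stream; everything else
    # takes the single-stream code path.
    return torch.cuda.Stream() if is_cuda else None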
@@ -52,7 +52,15 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
+from sglang.srt.utils import (
+    add_prefix,
+    fast_topk,
+    get_compiler_backend,
+    is_cuda,
+    make_layers,
+)
+
+_is_cuda = is_cuda()

 logger = logging.getLogger(__name__)
@@ -131,7 +139,7 @@
         return out_aD

     def _forward_core(self, hidden_states, forward_mode: ForwardMode):
-        if hidden_states.shape[0] < 4:
+        if hidden_states.shape[0] < 4 and _is_cuda:
             return self._forward_core_shared_routed_overlap(hidden_states)
         else:
             return self._forward_core_normal(hidden_states)
......
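The dispatch above now routes to _forward_core_shared_routed_overlap only for very small batches and only on CUDA. The body of that method is not part of this hunk; as a heavily hedged sketch, a shared/routed overlap of this kind typically runs the routed experts on the current stream while the shared expert runs on an alternate stream, using the same wait/join shape sketched earlier. All names below are stand-ins, not sglang code:

import torch

def shared_routed_overlap(hidden_states, shared_expert, routed_experts, alt_stream):
    # Hedged stand-in for a shared/routed expert overlap; not the actual
    # Llama4MoE._forward_core_shared_routed_overlap implementation.
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)
    routed_out = routed_experts(hidden_states)      # current stream
    with torch.cuda.stream(alt_stream):
        shared_out = shared_expert(hidden_states)   # overlapped on alt_stream
    current.wait_stream(alt_stream)                 # join before combining
    return shared_out + routed_out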