Unverified Commit 03886917 authored by Lianmin Zheng, committed by GitHub

Disable all two stream overlap on amd (#6475)

parent 66324895
@@ -38,11 +38,17 @@
 import triton
 import triton.language as tl
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, get_compiler_backend
+from sglang.srt.utils import (
+    debug_timing,
+    get_compiler_backend,
+    is_cuda,
+    next_power_of_2,
+)
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
+_is_cuda = is_cuda()
 
 
 class ReqToTokenPool:
@@ -262,7 +268,7 @@ class MHATokenToKVPool(KVCache):
         self.layer_transfer_counter = None
         self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream()
+        self.alt_stream = self.device_module.Stream() if _is_cuda else None
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -392,7 +398,7 @@ class MHATokenToKVPool(KVCache):
         cache_k = cache_k.view(self.store_dtype)
         cache_v = cache_v.view(self.store_dtype)
 
-        if self.capture_mode and cache_k.shape[0] < 4:
+        if self.capture_mode and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
......
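For context, the gate above controls a two-stream copy: when an alternate stream exists, the K and V cache writes are issued on separate streams and re-joined before later kernels read them; when it is None (now the case on AMD/ROCm), both writes run sequentially on the current stream. Below is a minimal sketch of that pattern, assuming hypothetical TinyKVPool, k_buffer, v_buffer, and set_kv names; it is not the real MHATokenToKVPool API.

import torch

class TinyKVPool:
    """Hedged sketch of the two-stream K/V copy overlap; not the real pool."""

    def __init__(self, k_buffer: torch.Tensor, v_buffer: torch.Tensor, device: str = "cuda"):
        self.k_buffer = k_buffer
        self.v_buffer = v_buffer
        self.device_module = torch.get_device_module(device)
        # Same gate as the patch: only CUDA gets a second stream.
        self.alt_stream = self.device_module.Stream() if device == "cuda" else None

    def set_kv(self, loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor):
        if self.alt_stream is not None:
            current_stream = self.device_module.current_stream()
            # The side stream must first catch up with work already queued.
            self.alt_stream.wait_stream(current_stream)
            self.k_buffer[loc] = cache_k              # K copy on the main stream
            with self.device_module.stream(self.alt_stream):
                self.v_buffer[loc] = cache_v          # V copy on the side stream
            # Re-join so later kernels see both copies.
            current_stream.wait_stream(self.alt_stream)
        else:
            # Non-CUDA backends (e.g. ROCm after this change): sequential copies.
            self.k_buffer[loc] = cache_k
            self.v_buffer[loc] = cache_v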
@@ -76,13 +76,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import (
-    ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,
 )
 from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
 from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.operations import execute_operations
 from sglang.srt.operations_strategy import compute_layer_operations
@@ -1321,8 +1320,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
-        # TODO(haishaw): multi-stream performance on ROCm
-        self.alt_stream = None if _is_hip else torch.cuda.Stream()
+        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
......
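The net effect of the DeepseekV2Model change is that the multi-stream path becomes opt-in for CUDA rather than opt-out for HIP: any backend that is not explicitly CUDA now gets alt_stream = None. The following is a hedged sketch of one plausible way such is_cuda()/is_hip() helpers can be written; sglang's actual helpers live in sglang.srt.utils and may differ in detail.

import torch

def is_hip() -> bool:
    # ROCm builds of PyTorch report a HIP version string.
    return torch.version.hip is not None

def is_cuda() -> bool:
    # CUDA builds expose torch.cuda and report a CUDA version.
    return hasattr(torch, "cuda") and torch.version.cuda is not None

_is_cuda = is_cuda()

# Before: keep the stream unless the build is explicitly HIP.
#   alt_stream = None if is_hip() else torch.cuda.Stream()
# After: create the stream only when the build is explicitly CUDA, so ROCm,
# CPU, and any other backend fall back to the single-stream path.
alt_stream = torch.cuda.Stream() if _is_cuda else None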
@@ -52,7 +52,15 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
-from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
+from sglang.srt.utils import (
+    add_prefix,
+    fast_topk,
+    get_compiler_backend,
+    is_cuda,
+    make_layers,
+)
+
+_is_cuda = is_cuda()
 
 logger = logging.getLogger(__name__)
@@ -131,7 +139,7 @@ class Llama4MoE(nn.Module):
         return out_aD
 
     def _forward_core(self, hidden_states, forward_mode: ForwardMode):
-        if hidden_states.shape[0] < 4:
+        if hidden_states.shape[0] < 4 and _is_cuda:
             return self._forward_core_shared_routed_overlap(hidden_states)
         else:
             return self._forward_core_normal(hidden_states)
......
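In Llama4MoE._forward_core the shared-expert/routed-expert overlap was already limited to tiny batches (fewer than 4 tokens); this change additionally restricts it to CUDA, so AMD always takes _forward_core_normal. Below is a hedged sketch of the dispatch and of the kind of two-stream overlap it selects; shared_expert, routed_experts, and alt_stream are illustrative placeholders, not the module's real attributes.

import torch

def forward_core(hidden_states, shared_expert, routed_experts, alt_stream, is_cuda_backend):
    # Only tiny decode batches on CUDA take the overlapped path after this patch.
    if hidden_states.shape[0] < 4 and is_cuda_backend and alt_stream is not None:
        current_stream = torch.cuda.current_stream()
        alt_stream.wait_stream(current_stream)
        shared_out = shared_expert(hidden_states)        # main stream
        with torch.cuda.stream(alt_stream):
            routed_out = routed_experts(hidden_states)   # side stream
        current_stream.wait_stream(alt_stream)
        return shared_out + routed_out
    # Everywhere else (including ROCm): run the two expert groups sequentially.
    return shared_expert(hidden_states) + routed_experts(hidden_states)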