Unverified Commit 8e2ac2e6 authored by Makcum888e, committed by GitHub

[NPU] fix pp_size>1 (#12195)

parent 17a57fd8
@@ -68,7 +68,7 @@ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
 
 @dataclass
 class GraphCaptureContext:
-    stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
+    stream: torch.get_device_module().Stream
 
 
 @dataclass
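The hunk above replaces the hard-coded CUDA/NPU branch with a single device-agnostic lookup. A minimal sketch of why this works, assuming a recent PyTorch that exposes torch.get_device_module and an initialized accelerator (the attribute names below are the ones the commit relies on, the rest is illustrative):

```python
import torch

# torch.get_device_module() returns the module for the active accelerator,
# e.g. torch.cuda on CUDA builds or torch.npu once torch_npu is loaded,
# so the explicit `if not _is_npu` branch becomes unnecessary.
mod = torch.get_device_module()

StreamCls = mod.Stream           # used by GraphCaptureContext
stream = mod.current_stream()    # used by the change_state(...) call sites below
print(mod.__name__, StreamCls, stream)
```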
@@ -498,7 +498,7 @@ class GroupCoordinator:
             maybe_pynccl_context = nullcontext()
         else:
             maybe_pynccl_context = pynccl_comm.change_state(
-                enable=True, stream=torch.cuda.current_stream()
+                enable=True, stream=torch.get_device_module().current_stream()
             )
         pymscclpp_comm = self.pymscclpp_comm
@@ -555,7 +555,7 @@ class GroupCoordinator:
                 and input_.symmetric_memory
             ):
                 with self.pynccl_comm.change_state(
-                    enable=True, stream=torch.cuda.current_stream()
+                    enable=True, stream=torch.get_device_module().current_stream()
                 ):
                     self.pynccl_comm.all_reduce(input_)
                 return input_
@@ -655,7 +655,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for reduce_scatterv"
@@ -779,7 +781,9 @@ class GroupCoordinator:
         world_size = self.world_size
         pynccl_comm = self.pynccl_comm
-        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+        with pynccl_comm.change_state(
+            enable=True, stream=torch.get_device_module().current_stream()
+        ):
             assert (
                 pynccl_comm is not None and not pynccl_comm.disabled
             ), "pynccl is required for all_gatherv"
...
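The four hunks above are the same mechanical change: the stream handed to pynccl_comm.change_state must come from the active device module rather than from torch.cuda. For readers unfamiliar with the pattern, a hedged sketch of what a change_state-style context manager typically does (illustrative only; sglang's actual PyNcclCommunicator may differ in detail):

```python
from contextlib import contextmanager

class CommSketch:
    """Illustrative stand-in for a pynccl-style communicator."""

    def __init__(self):
        self.disabled = True
        self.stream = None

    @contextmanager
    def change_state(self, enable: bool, stream):
        # Swap (enabled, stream) for the duration of the block, then restore,
        # so collectives issued inside run on the caller's current stream.
        prev = (self.disabled, self.stream)
        self.disabled, self.stream = not enable, stream
        try:
            yield
        finally:
            self.disabled, self.stream = prev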
@@ -1137,10 +1137,10 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         torch_npu._npu_reshape_and_cache(
             key=cache_k,
             value=cache_v,
-            key_cache=self.k_buffer[layer_id].view(
+            key_cache=self.k_buffer[layer_id - self.start_layer].view(
                 -1, self.page_size, self.head_num, self.head_dim
             ),
-            value_cache=self.v_buffer[layer_id].view(
+            value_cache=self.v_buffer[layer_id - self.start_layer].view(
                 -1, self.page_size, self.head_num, self.head_dim
             ),
             slot_indices=loc,
...
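This hunk is the core of the pp_size>1 fix: with pipeline parallelism each rank allocates k_buffer/v_buffer only for its own layer slice, so indexing the pool with the global layer_id reads past the locally allocated buffers. A small illustration of the offset, with made-up numbers:

```python
# Hypothetical numbers: a 32-layer model split across pp_size=2 ranks;
# this rank owns layers [16, 32), so it holds only 16 local KV buffers.
start_layer, end_layer = 16, 32
k_buffer = [f"kv-layer-{i}" for i in range(start_layer, end_layer)]

def local_k_buffer(layer_id: int):
    assert start_layer <= layer_id < end_layer, "layer not owned by this rank"
    return k_buffer[layer_id - start_layer]  # the offset this commit adds

print(local_k_buffer(20))  # "kv-layer-20": global id 20 -> local index 4
```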
@@ -1659,9 +1659,11 @@ class ModelRunner:
                     get_attention_tp_size()
                 ),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.use_mla_backend and is_nsa_model:
             self.token_to_kv_pool = NSATokenToKVPool(
...
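Correspondingly, ModelRunner now sizes the NPU pool with num_effective_layers (the per-rank slice) instead of the full num_hidden_layers, and passes the slice boundaries down so the pool can apply the subtraction shown above. A hedged sketch of how such a slice is commonly derived (illustrative helper, not sglang's exact code):

```python
def layer_slice(num_hidden_layers: int, pp_rank: int, pp_size: int):
    # Even split, with the last rank absorbing any remainder. Real
    # implementations may balance layers differently; this is only a sketch.
    per_rank = num_hidden_layers // pp_size
    start = pp_rank * per_rank
    end = num_hidden_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end

start_layer, end_layer = layer_slice(32, pp_rank=1, pp_size=2)
num_effective_layers = end_layer - start_layer
print(start_layer, end_layer, num_effective_layers)  # 16 32 16
```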
@@ -1239,42 +1239,34 @@ def point_to_point_pyobj(
     dst: int = 1,
 ):
     """Send data from src to dst in group using DeviceToDevice communication."""
+    device = torch.get_device_module().current_device()
     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor(
-                [0], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            tensor_size = torch.tensor([0], dtype=torch.long, device=device)
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
-            ).cuda(
-                device=torch.cuda.current_device()
-            )  # Move to GPU
-            tensor_size = torch.tensor(
-                [size], dtype=torch.long, device=torch.cuda.current_device()
-            )
+            ).to(
+                device=device
+            )  # Move to GPU
+            tensor_size = torch.tensor([size], dtype=torch.long, device=device)
             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data
     elif rank == dst:
-        tensor_size = torch.tensor(
-            [0], dtype=torch.long, device=torch.cuda.current_device()
-        )
+        tensor_size = torch.tensor([0], dtype=torch.long, device=device)
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()
         if size == 0:
             return []
-        tensor_data = torch.empty(
-            size, dtype=torch.uint8, device=torch.cuda.current_device()
-        )
+        tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
         dist.recv(tensor_data, src=src, group=group)
         serialized_data = bytes(
...
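The rewrite hoists the device lookup into a single `device` variable and swaps `.cuda(...)` for the device-agnostic `.to(device=...)`; the wire protocol is unchanged: a long tensor carrying the payload size, then the pickled bytes. A self-contained sketch of that protocol on CPU/gloo (hypothetical helper names, run under `torchrun --nproc_per_node=2`):

```python
import pickle

import torch
import torch.distributed as dist

def send_pyobj(obj, dst, group=None):
    # Size first, then the pickled payload as a uint8 tensor.
    payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
    dist.send(torch.tensor([payload.numel()], dtype=torch.long), dst=dst, group=group)
    dist.send(payload, dst=dst, group=group)

def recv_pyobj(src, group=None):
    size = torch.zeros(1, dtype=torch.long)
    dist.recv(size, src=src, group=group)
    if size.item() == 0:
        return []
    buf = torch.empty(size.item(), dtype=torch.uint8)
    dist.recv(buf, src=src, group=group)
    return pickle.loads(buf.numpy().tobytes())

if __name__ == "__main__":
    dist.init_process_group("gloo")
    if dist.get_rank() == 0:
        send_pyobj({"hello": "npu"}, dst=1)
    elif dist.get_rank() == 1:
        print(recv_pyobj(src=0))
```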