Unverified Commit 8e2ac2e6 authored by Makcum888e's avatar Makcum888e Committed by GitHub
Browse files

[NPU] fix pp_size>1 (#12195)

parent 17a57fd8
......@@ -68,7 +68,7 @@ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
@dataclass
class GraphCaptureContext:
stream: torch.cuda.Stream if not _is_npu else torch.npu.Stream
stream: torch.get_device_module().Stream
@dataclass
......@@ -498,7 +498,7 @@ class GroupCoordinator:
maybe_pynccl_context = nullcontext()
else:
maybe_pynccl_context = pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
enable=True, stream=torch.get_device_module().current_stream()
)
pymscclpp_comm = self.pymscclpp_comm
......@@ -555,7 +555,7 @@ class GroupCoordinator:
and input_.symmetric_memory
):
with self.pynccl_comm.change_state(
enable=True, stream=torch.cuda.current_stream()
enable=True, stream=torch.get_device_module().current_stream()
):
self.pynccl_comm.all_reduce(input_)
return input_
......@@ -655,7 +655,9 @@ class GroupCoordinator:
world_size = self.world_size
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
with pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for reduce_scatterv"
......@@ -779,7 +781,9 @@ class GroupCoordinator:
world_size = self.world_size
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
with pynccl_comm.change_state(
enable=True, stream=torch.get_device_module().current_stream()
):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for all_gatherv"
......
......@@ -1137,10 +1137,10 @@ class AscendTokenToKVPool(MHATokenToKVPool):
torch_npu._npu_reshape_and_cache(
key=cache_k,
value=cache_v,
key_cache=self.k_buffer[layer_id].view(
key_cache=self.k_buffer[layer_id - self.start_layer].view(
-1, self.page_size, self.head_num, self.head_dim
),
value_cache=self.v_buffer[layer_id].view(
value_cache=self.v_buffer[layer_id - self.start_layer].view(
-1, self.page_size, self.head_num, self.head_dim
),
slot_indices=loc,
......
......@@ -1659,9 +1659,11 @@ class ModelRunner:
get_attention_tp_size()
),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
)
elif self.use_mla_backend and is_nsa_model:
self.token_to_kv_pool = NSATokenToKVPool(
......
......@@ -1239,42 +1239,34 @@ def point_to_point_pyobj(
dst: int = 1,
):
"""Send data from src to dst in group using DeviceToDevice communication."""
device = torch.get_device_module().current_device()
if rank == src:
if len(data) == 0:
tensor_size = torch.tensor(
[0], dtype=torch.long, device=torch.cuda.current_device()
)
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
dist.send(tensor_size, dst=dst, group=group)
else:
serialized_data = pickle.dumps(data)
size = len(serialized_data)
tensor_data = torch.ByteTensor(
np.frombuffer(serialized_data, dtype=np.uint8)
).cuda(
device=torch.cuda.current_device()
).to(
device=device
) # Move to GPU
tensor_size = torch.tensor(
[size], dtype=torch.long, device=torch.cuda.current_device()
)
tensor_size = torch.tensor([size], dtype=torch.long, device=device)
dist.send(tensor_size, dst=dst, group=group)
dist.send(tensor_data, dst=dst, group=group)
return data
elif rank == dst:
tensor_size = torch.tensor(
[0], dtype=torch.long, device=torch.cuda.current_device()
)
tensor_size = torch.tensor([0], dtype=torch.long, device=device)
dist.recv(tensor_size, src=src, group=group)
size = tensor_size.item()
if size == 0:
return []
tensor_data = torch.empty(
size, dtype=torch.uint8, device=torch.cuda.current_device()
)
tensor_data = torch.empty(size, dtype=torch.uint8, device=device)
dist.recv(tensor_data, src=src, group=group)
serialized_data = bytes(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment