Unverified Commit fd7e15b7 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

Revert "[bug fix] Ensure local token and global token buffers are pointing to...

Revert "[bug fix] Ensure local token and global token buffers are pointing to different storage " (#8993)
parent fc42ff7b
...@@ -264,10 +264,9 @@ def _dp_gather_via_all_reduce( ...@@ -264,10 +264,9 @@ def _dp_gather_via_all_reduce(
assert global_tokens.is_contiguous() assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0): if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
if local_tokens.untyped_storage() is global_tokens.untyped_storage(): assert (
# dp_gather is an in-place operation and requires input and output tensors to not be aliased. local_tokens.untyped_storage() is not global_tokens.untyped_storage()
# so we create a separate buffer if they share the same storage. ), "aliasing between global_tokens and local_tokens not allowed"
global_tokens = torch.empty_like(global_tokens)
memcpy_triton( memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
...@@ -348,10 +347,9 @@ def dp_scatter( ...@@ -348,10 +347,9 @@ def dp_scatter(
assert local_tokens.is_contiguous() assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous() assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0: if local_tokens.shape[0] > 0:
if local_tokens.untyped_storage() is global_tokens.untyped_storage(): assert (
# dp_scatter is an in-place operation and requires input and output tensors to not be aliased. local_tokens.untyped_storage() is not global_tokens.untyped_storage()
# so we create a separate buffer if they share the same storage. ), "aliasing between local_tokens and global_tokens not allowed"
local_tokens = torch.empty_like(local_tokens)
memcpy_triton( memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment