Unverified Commit 91bb69a8 authored by sky, committed by GitHub
Browse files

style: remove trailing whitespace (#373)


Signed-off-by: wangfakang <fakangwang@gmail.com>
parent f0d34aab
......@@ -65,7 +65,7 @@ ln -s build/lib.linux-x86_64-cpython-38/deep_ep_cpp.cpython-38-x86_64-linux-gnu.
# Run test cases
# NOTES: you may modify the `init_dist` function in `tests/utils.py`
# according to your own cluster settings, and launch into multiple nodes
# according to your own cluster settings, and launch into multiple nodes
python tests/test_intranode.py
python tests/test_internode.py
python tests/test_low_latency.py
......@@ -79,7 +79,7 @@ NVSHMEM_DIR=/path/to/installed/nvshmem python setup.py install
#### Installation environment variables
- `NVSHMEM_DIR`: the path to the NVSHMEM directory, disable all internode and low-latency features if not specified
- `NVSHMEM_DIR`: the path to the NVSHMEM directory, disable all internode and low-latency features if not specified
- `DISABLE_SM90_FEATURES`: 0 or 1, whether to disable SM90 features; this is required for non-SM90 devices or CUDA 11
- `TORCH_CUDA_ARCH_LIST`: the list of target architectures, e.g. `TORCH_CUDA_ARCH_LIST="9.0"`
- `DISABLE_AGGRESSIVE_PTX_INSTRS`: 0 or 1, whether to disable aggressive load/store instructions, see [Undefined-behavior PTX usage](#undefined-behavior-ptx-usage) for more details
......@@ -137,7 +137,7 @@ Buffer.set_num_sms(24)
# You may call this function at the framework initialization
def get_buffer(group: dist.ProcessGroup, hidden_bytes: int) -> Buffer:
global _buffer
# NOTES: you may also replace `get_*_config` with your auto-tuned results via all the tests
num_nvl_bytes, num_rdma_bytes = 0, 0
for config in (Buffer.get_dispatch_config(group.size()), Buffer.get_combine_config(group.size())):
......@@ -159,7 +159,7 @@ def dispatch_forward(x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
topk_idx: torch.Tensor, topk_weights: torch.Tensor,
num_experts: int, previous_event: Optional[EventOverlap] = None) -> \
Tuple[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], torch.Tensor, torch.Tensor, List, Tuple, EventOverlap]:
# NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
# NOTES: an optional `previous_event` means a CUDA event captured that you want to make it as a dependency
# of the dispatch kernel, it may be useful with communication-computation overlap. For more information, please
# refer to the docs of `Buffer.dispatch`
global _buffer
......
......@@ -597,7 +597,7 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional<torch::Ten
num_channels, num_recv_tokens, num_channels * num_ranks * 2,
barrier_signal_ptrs_gpu, rank, num_ranks,
comm_stream);
// Assign bias pointers
auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
void* bias_ptrs[2] = {nullptr, nullptr};
......@@ -1007,7 +1007,7 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional<torch::Ten
barrier_signal_ptrs_gpu, rank, comm_stream,
config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks),
num_nvl_bytes, false, low_latency_mode);
// Assign bias pointers
auto bias_opts = std::vector<std::optional<torch::Tensor>>({bias_0, bias_1});
void* bias_ptrs[2] = {nullptr, nullptr};
......
......@@ -272,7 +272,7 @@ nvshmemi_ibgda_rma_p(int *rptr, const int value, int dst_pe, int qp_id, uint32_t
uint64_t raddr;
auto qp = ibgda_get_rc(dst_pe, qp_id);
ibgda_get_rkey(reinterpret_cast<uint64_t>(rptr), dst_pe, &raddr, &rkey, qp->dev_idx);
// Write WQEs
uint64_t base_wqe_idx = ibgda_reserve_wqe_slots(qp, 1);
void *wqe_ptrs;
......@@ -351,13 +351,13 @@ nvshmemi_ibgda_put_nbi_warp(uint64_t req_rptr, uint64_t req_lptr, size_t bytes,
auto remaining_bytes = bytes;
while (remaining_bytes > 0) {
if (lane_id == num_wqes) {
my_chunk_size = min(remaining_bytes,
ibgda_get_lkey_and_rkey(my_laddr = req_lptr,
&my_lkey,
req_rptr,
dst_pe,
&my_raddr,
&my_rkey,
my_chunk_size = min(remaining_bytes,
ibgda_get_lkey_and_rkey(my_laddr = req_lptr,
&my_lkey,
req_rptr,
dst_pe,
&my_raddr,
&my_rkey,
qp->dev_idx));
}
......@@ -464,7 +464,7 @@ __device__ __forceinline__ uint64_t nvshmemi_get_p2p_ptr(const uint64_t& ptr, co
return peer_base + (ptr - reinterpret_cast<uint64_t>(nvshmemi_device_state_d.heap_base));
}
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.
// This is a simplified version of NVSHMEM's `ibgda_poll_cq`.
// Note that this implementation does not guarantee thread safety,
// so we must ensure that no other threads are concurrently using the same QP.
__device__ static __forceinline__ void
......
......@@ -63,7 +63,7 @@ std::pair<int, int> get_nvl_clean_meta(int hidden_int4, int num_scales, int num_
int num_nvl_recv_buffer_tokens, int num_channels, bool is_dispatch) {
// Return `int32_t` offset and to clean
EP_STATIC_ASSERT(sizeof(SourceMeta) % sizeof(int) == 0, "Invalid size of `SourceMeta`");
return {
(num_nvl_recv_buffer_tokens * get_num_bytes_per_token(hidden_int4, num_scales, num_topk_idx, num_topk_weights) * num_nvl_ranks * num_channels) / sizeof(int),
num_nvl_ranks * (2 * num_rdma_ranks + 2) * num_channels,
......@@ -150,8 +150,8 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in
reinterpret_cast<uint64_t>(rdma_recv_num_tokens_mixed.send_buffer(i)),
(NUM_MAX_NVL_PEERS + num_rdma_experts + 1) * sizeof(int),
translate_dst_rdma_rank<kLowLatencyMode>(i, nvl_rank), 0, lane_id, 0);
} else {
UNROLLED_WARP_COPY(1, lane_id, NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
} else {
UNROLLED_WARP_COPY(1, lane_id, NUM_MAX_NVL_PEERS + num_rdma_experts + 1,
rdma_recv_num_tokens_mixed.recv_buffer(rdma_rank),
rdma_recv_num_tokens_mixed.send_buffer(i),
ld_volatile_global, st_na_global);
......@@ -798,7 +798,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
int dst_slot_idx = (cached_nvl_channel_tail ++) % num_max_nvl_chunked_recv_tokens;
auto dst_shifted = nvl_channel_x.buffer() + dst_slot_idx * num_bytes_per_token;
// Copy data
// Copy data
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, num_bytes_per_token, false);
mbarrier_arrive_and_expect_tx(tma_mbarrier, num_bytes_per_token);
......@@ -936,7 +936,7 @@ dispatch(int4* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv
bool scale_aligned = (scale_bytes % 16 == 0);
auto tma_load_bytes = hidden_bytes + (scale_aligned ? scale_bytes : 0);
// Copy data
// Copy data
if (lane_id == 0) {
tma_load_1d(tma_buffer, shifted, tma_mbarrier, tma_load_bytes);
mbarrier_arrive_and_expect_tx(tma_mbarrier, tma_load_bytes);
......@@ -1129,7 +1129,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
constexpr int num_bytes_per_token = sizeof(int) * NUM_MAX_NVL_PEERS;
constexpr int num_tokens_per_batch = tma_batch_size / num_bytes_per_token;
EP_STATIC_ASSERT(num_bytes_per_token % 16 == 0, "num_bytes_per_token should be divisible by 16");
// TMA stuffs
extern __shared__ __align__(1024) uint8_t smem_tma_buffer[];
auto tma_buffer = smem_tma_buffer + warp_id * kNumTMABytesPerWarp;
......@@ -1174,7 +1174,7 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in
tma_store_fence();
__syncwarp();
if (lane_id == 0)
if (lane_id == 0)
tma_store_1d(tma_buffer, combined_nvl_head + batch_start_idx * NUM_MAX_NVL_PEERS, (batch_end_idx - batch_start_idx) * num_bytes_per_token);
tma_store_wait();
__syncwarp();
......@@ -1226,7 +1226,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
int lane_id, int hidden_int4, int num_topk,
int4* combined_row, float* combined_topk_weights,
const int4* bias_0_int4, const int4* bias_1_int4,
int num_max_recv_tokens, const GetAddrFn& get_addr_fn, const ReceiveTWFn& recv_tw_fn,
int num_max_recv_tokens, const GetAddrFn& get_addr_fn, const ReceiveTWFn& recv_tw_fn,
uint8_t* smem_ptr, uint32_t (&tma_phase)[kNumStages]) {
constexpr auto kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t);
......@@ -1251,7 +1251,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
auto tma_load_buffer = [=](const int& i, const int& j) -> int4* { return reinterpret_cast<int4*>(smem_ptr + i * kNumTMABufferBytesPerStage + j * kNumTMALoadBytes); };
auto tma_store_buffer = [=](const int& i) -> int4* { return reinterpret_cast<int4*>(smem_ptr + i * kNumTMABufferBytesPerStage + NUM_MAX_NVL_PEERS * kNumTMALoadBytes); };
auto tma_mbarrier = [=](const int& i) -> uint64_t* { return reinterpret_cast<uint64_t*>(smem_ptr + i * kNumTMABufferBytesPerStage + (NUM_MAX_NVL_PEERS + 1) * kNumTMALoadBytes); };
// Prefetch
if (lane_id < num_topk_ranks)
tma_load_1d(tma_load_buffer(0, lane_id), get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], 0), tma_mbarrier(0), kNumTMALoadBytes);
......@@ -1262,7 +1262,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
const int stage_idx = iter % kNumStages;
const int next_stage_idx = (iter + 1) % kNumStages;
// Prefetch next stage
// Prefetch next stage
if (shifted + 32 < hidden_int4) {
if (lane_id < num_topk_ranks)
tma_load_1d(tma_load_buffer(next_stage_idx, lane_id), get_addr_fn(topk_ranks[lane_id], slot_indices[lane_id], shifted + 32), tma_mbarrier(next_stage_idx), kNumTMALoadBytes);
......@@ -1312,7 +1312,7 @@ __device__ int combine_token(bool is_token_in_rank, int head_idx,
#pragma unroll
for (int j = 0; j < num_topk_ranks; ++ j)
recv_value_int4[j] = ld_nc_global(get_addr_fn(topk_ranks[j], slot_indices[j], i));
// Clean
// Reduce bias
float values[kDtypePerInt4] = {0};
......@@ -1756,7 +1756,7 @@ combine(int4* combined_x, float* combined_topk_weights,
combined_topk_weights + token_idx * num_topk,
bias_0 == nullptr ? nullptr : bias_0 + token_idx * hidden_int4,
bias_1 == nullptr ? nullptr : bias_1 + token_idx * hidden_int4,
num_max_rdma_chunked_recv_tokens, get_addr_fn, recv_tw_fn,
num_max_rdma_chunked_recv_tokens, get_addr_fn, recv_tw_fn,
nullptr, dummy_tma_phases);
}
......
......@@ -489,7 +489,7 @@ __forceinline__ __device__ void logfmt_check_amaxmin(uint8_t* meta_buffer, float
const auto& bf162_amaxmin = reinterpret_cast<__nv_bfloat162*>(&amaxmin2);
float log_amax[2], log_amin[2];
#pragma unroll
for (int i = 0; i < 2; ++ i) {
for (int i = 0; i < 2; ++ i) {
auto amax = static_cast<float>(bf162_amaxmin[i].x);
auto amin = static_cast<float>(bf162_amaxmin[i].y);
log_amax[i] = log2f_approx(amax);
......
......@@ -92,7 +92,7 @@ class Buffer:
# Synchronize NVSHMEM unique IDs
root_unique_id = None
if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode:
# Enable IBGDA
# Enable IBGDA
assert num_qps_per_rank > 0
os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1'
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
......@@ -125,7 +125,7 @@ class Buffer:
def destroy(self):
"""
Destroy the cpp runtime and release resources.
"""
assert self.explicitly_destroy, '`explicitly_destroy` flag must be set'
......@@ -175,13 +175,13 @@ class Buffer:
size: the RDMA buffer size recommended.
"""
return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts)
def get_comm_stream(self) -> torch.Stream:
    """
    Return the communication stream exposed by the runtime.

    Returns:
        stream: the communication stream, rebuilt as a `torch.cuda.Stream` from the
            stream id, device index and device type reported by the runtime.
    """
    runtime_stream: torch.Stream = self.runtime.get_comm_stream()
    stream_kwargs = dict(
        stream_id=runtime_stream.stream_id,
        device_index=runtime_stream.device_index,
        device_type=runtime_stream.device_type,
    )
    return torch.cuda.Stream(**stream_kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment