Commit 94694314 authored by lishen's avatar lishen
Browse files

接入ROCSHMEM的multiqp更新后的代码

parent 3872dd54
...@@ -6,6 +6,7 @@ export OMPI_MCA_coll_hcoll_enable=0 ...@@ -6,6 +6,7 @@ export OMPI_MCA_coll_hcoll_enable=0
export UCX_TLS=rc,rocm export UCX_TLS=rc,rocm
# export ROCSHMEM_UNIQUEID_WITH_MPI=1 # export ROCSHMEM_UNIQUEID_WITH_MPI=1
export OMPI_MCA_rmaps_base_mapping_policy="slot:numa" export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384 export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
...@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ...@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=10737418240
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py --pressure-test
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility
...@@ -6,6 +6,7 @@ export OMPI_MCA_coll_hcoll_enable=0 ...@@ -6,6 +6,7 @@ export OMPI_MCA_coll_hcoll_enable=0
export UCX_TLS=rc,rocm export UCX_TLS=rc,rocm
# export ROCSHMEM_UNIQUEID_WITH_MPI=1 # export ROCSHMEM_UNIQUEID_WITH_MPI=1
export OMPI_MCA_rmaps_base_mapping_policy="slot:numa" export OMPI_MCA_rmaps_base_mapping_policy="slot:numa"
export ROCSHMEM_GDA_NUM_QPS_DEFAULT_CTX=288
export ROCSHMEM_MAX_NUM_CONTEXTS=48 export ROCSHMEM_MAX_NUM_CONTEXTS=48
export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384 export UCX_ROCM_IPC_SIGPOOL_MAX_ELEMS=16384
export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
...@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ...@@ -15,5 +16,5 @@ export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export ROCSHMEM_HEAP_SIZE=10737418240 export ROCSHMEM_HEAP_SIZE=10737418240
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py --pressure-test
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility
...@@ -430,6 +430,12 @@ LOW_LATENCY_DISPATCH_RECV: ...@@ -430,6 +430,12 @@ LOW_LATENCY_DISPATCH_RECV:
recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx); recv_range[src_rank] = pack2<int, int64_t>(num_recv_tokens, recv_token_begin_idx);
} }
#if defined(ROCM_USE_MULTIQP)
if (sub_warp_id == 2 and lane_id == 0) {
internode::shmem_qp_quiet(num_ranks + responsible_expert_idx);
}
#endif
// no needs to reset because there is no iteration // no needs to reset because there is no iteration
if (lane_id == 0){ if (lane_id == 0){
volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); volatile int ret = __hip_atomic_fetch_add(&sync_large_warp_counters[warp_group_id], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
...@@ -696,10 +702,6 @@ combine(void* combined_x, ...@@ -696,10 +702,6 @@ combine(void* combined_x,
atomic_add_release_global(atomic_clean_flag, -1); atomic_add_release_global(atomic_clean_flag, -1);
} }
syncwarp(); syncwarp();
// if (num_ranks > 8){
// internode::shmem_fence();
// }
} }
// Receiving phase // Receiving phase
...@@ -728,6 +730,11 @@ LOW_LATENCY_COMBINE_RECV: ...@@ -728,6 +730,11 @@ LOW_LATENCY_COMBINE_RECV:
atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost); atomicAdd(reinterpret_cast<unsigned long long*>(combine_wait_recv_cost_stats + src_rank), wait_recv_cost);
} }
} }
#if defined(ROCM_USE_MULTIQP)
if (sub_warp_id == 2 and lane_id == 0) {
internode::shmem_qp_quiet(num_ranks + responsible_expert_idx);
}
#endif
} }
grid_barrier(global_atomic_counter, num_sms); grid_barrier(global_atomic_counter, num_sms);
......
...@@ -117,6 +117,10 @@ __device__ inline void shmem_long_atomic_add( ...@@ -117,6 +117,10 @@ __device__ inline void shmem_long_atomic_add(
} }
#if defined(ROCM_USE_MULTIQP) #if defined(ROCM_USE_MULTIQP)
__device__ inline void shmem_qp_quiet(int idx_qp) {
rocshmem::rocshmem_quiet_dp(idx_qp);
}
__device__ inline void shmemx_int8_put_nbi_warp_dp( __device__ inline void shmemx_int8_put_nbi_warp_dp(
signed char *dest, const signed char *source, size_t nelems, int qp_idx, int pe) { signed char *dest, const signed char *source, size_t nelems, int qp_idx, int pe) {
rocshmem::rocshmem_schar_put_nbi_wave_dp(dest, source, nelems, qp_idx, pe); rocshmem::rocshmem_schar_put_nbi_wave_dp(dest, source, nelems, qp_idx, pe);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment