Commit 61dc5d91 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev-d2dpcie' into v0.9.2-dev

parents a5a9263e d65c5085
...@@ -1056,9 +1056,21 @@ class CustomAllreduce { ...@@ -1056,9 +1056,21 @@ class CustomAllreduce {
size /= d; size /= d;
auto bytes = size * sizeof(typename packed_t<T>::P); auto bytes = size * sizeof(typename packed_t<T>::P);
int blocks = std::min(block_limit, (size + threads - 1) / threads); int blocks = std::min(block_limit, (size + threads - 1) / threads);
// #define KL(ngpus, name) \
// name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \
// rank_, size, dev_curr_hdp_reg, world_size_) ;
#define KL(ngpus, name) \ #define KL(ngpus, name) \
name<T, ngpus><<<blocks, threads, 0, stream>>>(ptrs, sg_, self_sg_, output, \ { \
rank_, size, dev_curr_hdp_reg, world_size_) ; void* kernelArgs[] = { \
&ptrs, &sg_, &self_sg_, &output, &rank_, &size \
}; \
hipExtLaunchKernel( \
(void*)name<T, ngpus>, \
blocks, threads, \
kernelArgs, 0, \
stream, nullptr, stopEvent, 0 \
); \
}
#define REDUCE_CASE(ngpus) \ #define REDUCE_CASE(ngpus) \
case ngpus: { \ case ngpus: { \
...@@ -1066,7 +1078,7 @@ class CustomAllreduce { ...@@ -1066,7 +1078,7 @@ class CustomAllreduce {
KL(ngpus, cross_device_reduce_1stage_pcie); \ KL(ngpus, cross_device_reduce_1stage_pcie); \
} else { \ } else { \
if ((world_size_ <= 4 && bytes < 128 * 8192) || \ if ((world_size_ <= 4 && bytes < 128 * 8192) || \
(world_size_ <= 8 && bytes < 8 * 8192)) { \ (world_size_ <= 8 && bytes < 8 * 8192)) { \
KL(ngpus, cross_device_reduce_1stage_pcie); \ KL(ngpus, cross_device_reduce_1stage_pcie); \
} else { \ } else { \
KL(ngpus, cross_device_reduce_2stage_pcie); \ KL(ngpus, cross_device_reduce_2stage_pcie); \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment