Unverified Commit 3a408158 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Added missing thread offsets and other information to reduce. (#646)

parent b060c9f7
...@@ -42,7 +42,8 @@ struct AllReduce { ...@@ -42,7 +42,8 @@ struct AllReduce {
if constexpr (offset == scale) { if constexpr (offset == scale) {
return x; return x;
} else { } else {
return AllReduce<Reducer, offset, scale>::run(x, red_buf); return AllReduce<Reducer, offset, scale, thread_offset, all_threads>::run(
x, red_buf);
} }
} }
...@@ -51,7 +52,7 @@ struct AllReduce { ...@@ -51,7 +52,7 @@ struct AllReduce {
constexpr int offset = threads / 2; constexpr int offset = threads / 2;
if constexpr (offset >= 32) { if constexpr (offset >= 32) {
asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads)); asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads));
red_buf[threadIdx.x] = x; red_buf[threadIdx.x - thread_offset] = x;
// TODO(lei): maybe we can merge the two bar.sync into one? // TODO(lei): maybe we can merge the two bar.sync into one?
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads)); asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]); x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
...@@ -61,8 +62,8 @@ struct AllReduce { ...@@ -61,8 +62,8 @@ struct AllReduce {
if constexpr (offset == scale) { if constexpr (offset == scale) {
return x; return x;
} else { } else {
return AllReduce<Reducer, offset, scale, all_threads>::run_hopper( return AllReduce<Reducer, offset, scale, thread_offset,
x, red_buf); all_threads>::run_hopper(x, red_buf);
} }
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment