Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
25e1816e
"scripts/vscode:/vscode.git/clone" did not exist on "4411e788b4588d17c1098a227567d6ac0b6469a4"
Unverified
Commit
25e1816e
authored
Mar 17, 2025
by
Yi Zhang
Committed by
GitHub
Mar 16, 2025
Browse files
fix custom allreduce performance/accuracy problem (#4477)
parent
a53fe428
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
20 deletions
+7
-20
sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
+6
-19
sgl-kernel/include/trt_reduce_internal.cuh
sgl-kernel/include/trt_reduce_internal.cuh
+1
-1
No files found.
sgl-kernel/csrc/allreduce/trt_reduce_internal.cu
View file @
25e1816e
...
...
@@ -182,8 +182,9 @@ __inline__ __device__ void block_barrier(
}
}
}
if
constexpr
(
start
||
need_fence
)
{
__syncthreads
();
}
}
template
<
typename
T
,
int
RANKS_PER_NODE
,
bool
COPY_INPUT
=
true
>
...
...
@@ -262,6 +263,8 @@ static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduc
// Store to the destination buffer.
*
reinterpret_cast
<
int4
*>
(
&
reinterpret_cast
<
T
*>
(
params
.
local_output_buffer_ptr
)[
iter_offset
])
=
sums
.
packed
;
}
block_barrier
<
false
>
(
params
.
peer_barrier_ptrs_out
,
params
.
barrier_flag
,
params
.
local_rank
,
RANKS_PER_NODE
,
tidx
,
bidx
,
grid_size
);
}
template
<
typename
T
,
int
RANKS_PER_NODE
,
bool
COPY_INPUT
=
true
>
...
...
@@ -437,24 +440,8 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
assert
(
params
.
elts_total
%
(
elts_per_thread
*
params
.
ranks_per_node
)
==
0
);
size_t
const
total_threads
=
roundUp
(
params
.
elts_total
/
(
elts_per_thread
*
params
.
ranks_per_node
),
WARP_SIZE
);
/*
threads_per_block
=
std
::
min
(
DEFAULT_BLOCK_SIZE
,
total_threads
);
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
*/
while
(
total_threads
%
blocks_per_grid
!=
0
||
total_threads
/
blocks_per_grid
>
DEFAULT_BLOCK_SIZE
)
{
blocks_per_grid
+=
1
;
}
threads_per_block
=
total_threads
/
blocks_per_grid
;
// NOTE: need to adjust here
if
(
blocks_per_grid
>
MAX_ALL_REDUCE_BLOCKS
)
{
size_t
iter_factor
=
1
;
while
(
blocks_per_grid
/
iter_factor
>
MAX_ALL_REDUCE_BLOCKS
||
blocks_per_grid
%
iter_factor
)
{
iter_factor
+=
1
;
}
blocks_per_grid
/=
iter_factor
;
}
blocks_per_grid
=
std
::
min
(
static_cast
<
int
>
(
MAX_ALL_REDUCE_BLOCKS
),
divUp
(
total_threads
,
threads_per_block
));
params
.
elts_per_rank
=
params
.
elts_total
/
params
.
ranks_per_node
;
params
.
rank_offset
=
params
.
local_rank
*
params
.
elts_per_rank
;
params
.
elts_per_block
=
roundUp
(
divUp
(
params
.
elts_per_rank
,
blocks_per_grid
),
elts_per_thread
);
...
...
sgl-kernel/include/trt_reduce_internal.cuh
View file @
25e1816e
...
...
@@ -39,7 +39,7 @@ limitations under the License.
namespace
trt_llm
{
constexpr
size_t
WARP_SIZE
=
32
;
constexpr
size_t
MAX_ALL_REDUCE_BLOCKS
=
3
6
;
constexpr
size_t
MAX_ALL_REDUCE_BLOCKS
=
3
2
;
constexpr
size_t
MAX_RANKS_PER_NODE
=
8
;
constexpr
size_t
DEFAULT_BLOCK_SIZE
=
512
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment