Commit e04d3f28 (unverified), authored Dec 15, 2024 by yizhang2077, committed by GitHub on Dec 15, 2024.

adapt tensorrt llm custom all reduce to sgl-kernel (#2481)

Co-authored-by: Yineng Zhang <me@zhyncs.com>

Parent: 5f2595be
Showing 13 changed files with 872 additions and 32 deletions (+872, -32).
Changed files:

  sgl-kernel/CMakeLists.txt                                +47   -19
  sgl-kernel/Makefile                                       +2    -2
  sgl-kernel/pyproject.toml                                 +1    -1
  sgl-kernel/setup.py                                      +25    -1
  sgl-kernel/src/sgl-kernel/__init__.py                     +7    -2
  sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc             +13    -0
  sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu   +282    -0
  sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh   +91    -0
  sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu     +102    -0
  sgl-kernel/src/sgl-kernel/csrc/utils.hpp                 +36    -0
  sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc             +3    -7
  sgl-kernel/src/sgl-kernel/ops/__init__.py                +15    -0
  sgl-kernel/tests/test_trt_reduce.py                     +248    -0
sgl-kernel/CMakeLists.txt    (view file @ e04d3f28)

 cmake_minimum_required(VERSION 3.18)
 project(sgl-kernel LANGUAGES CXX CUDA)

 # Basic settings
 set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CUDA_STANDARD 17)
 set(CMAKE_CUDA_STANDARD_REQUIRED ON)

-find_package(PythonInterp 3 REQUIRED)
-find_package(PythonLibs 3 REQUIRED)

 # Set CUDA architectures
 set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90")
 message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")

+find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

 # Find PyTorch
 execute_process(
-    COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
+    COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
     OUTPUT_VARIABLE TORCH_CMAKE_PATH
     OUTPUT_STRIP_TRAILING_WHITESPACE
 )
 message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}")
 message(STATUS "TORCH_CMAKE_PATH: ${TORCH_CMAKE_PATH}")
 list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}")

 find_package(Torch REQUIRED)

 include_directories(${PYTHON_INCLUDE_DIRS})

 # Warp Reduce library
 add_library(warp_reduce SHARED
     src/sgl-kernel/csrc/warp_reduce.cc
     src/sgl-kernel/csrc/warp_reduce_kernel.cu
 )

-target_include_directories(warp_reduce PRIVATE
-    ${CUDA_INCLUDE_DIRS}
-    ${TORCH_INCLUDE_DIRS}
+target_include_directories(warp_reduce
+    PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
+    ${CUDA_INCLUDE_DIRS}
+    ${TORCH_INCLUDE_DIRS}
 )

-target_link_libraries(warp_reduce PRIVATE
-    ${TORCH_LIBRARIES}
-    ${PYTHON_LIBRARIES}
+target_link_libraries(warp_reduce
+    PRIVATE
+    ${TORCH_LIBRARIES}
+    Python3::Python
 )

-set_target_properties(warp_reduce PROPERTIES
-    CUDA_SEPARABLE_COMPILATION ON
+# TRT Reduce library
+add_library(trt_reduce SHARED
+    src/sgl-kernel/csrc/trt_reduce.cc
+    src/sgl-kernel/csrc/trt_reduce_internal.cu
+    src/sgl-kernel/csrc/trt_reduce_kernel.cu
+)
+
+target_include_directories(trt_reduce
+    PRIVATE
+    ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc
+    ${CUDA_INCLUDE_DIRS}
+    ${TORCH_INCLUDE_DIRS}
+)
+
+target_link_libraries(trt_reduce
+    PRIVATE
+    ${TORCH_LIBRARIES}
+    Python3::Python
+)
+
+# Set common properties for both libraries
+foreach(target warp_reduce trt_reduce)
+    set_target_properties(${target} PROPERTIES
+        CUDA_SEPARABLE_COMPILATION ON
         POSITION_INDEPENDENT_CODE ON
         CUDA_RESOLVE_DEVICE_SYMBOLS ON
         PREFIX ""
         SUFFIX ".so"
     )
+endforeach()
sgl-kernel/Makefile    (view file @ e04d3f28)

...
@@ -10,7 +10,7 @@ install:
 	@pip install -e .

 build:
-	@python3 setup.py bdist_wheel
+	@export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel

 clean:
 	@rm -rf build dist *.egg-info
...
@@ -19,4 +19,4 @@ test:
 	@pytest tests/

 format:
-	@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
+	@find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black
sgl-kernel/pyproject.toml    (view file @ e04d3f28)

...
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post4"
+version = "0.0.2.post5"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"
...
sgl-kernel/setup.py    (view file @ e04d3f28)

...
@@ -84,7 +84,31 @@ setup(
         },
         libraries=["c10", "torch", "torch_python"],
         extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
-    )
+    ),
+    CUDAExtension(
+        "sgl_kernel.ops.custom_reduce_cuda",
+        [
+            "src/sgl-kernel/csrc/trt_reduce_internal.cu",
+            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
+            "src/sgl-kernel/csrc/trt_reduce.cc",
+        ],
+        extra_compile_args={
+            "nvcc": [
+                "-O3",
+                "-Xcompiler",
+                "-fPIC",
+                "-gencode=arch=compute_75,code=sm_75",
+                "-gencode=arch=compute_80,code=sm_80",
+                "-gencode=arch=compute_89,code=sm_89",
+                "-gencode=arch=compute_90,code=sm_90",
+                "-U__CUDA_NO_HALF_OPERATORS__",
+                "-U__CUDA_NO_HALF2_OPERATORS__",
+            ],
+            "cxx": ["-O3"],
+        },
+        libraries=["c10", "torch", "torch_python"],
+        extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
+    ),
 ],
 cmdclass={"build_ext": BuildExtension},
 install_requires=["torch"],
...
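For orientation: the -gencode flags above target compute capabilities 7.5, 8.0, 8.9, and 9.0 (the CMake build additionally lists 8.6). The following is an illustrative sketch, not part of this commit, for checking whether the local GPU falls into that set:

# Illustrative only (not part of this commit): check whether the local GPU's compute
# capability matches one of the architectures targeted by the nvcc flags above.
import torch

NVCC_TARGETS = {(7, 5), (8, 0), (8, 9), (9, 0)}  # from setup.py; CMakeLists.txt also adds (8, 6)

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability(0)  # e.g. (9, 0) on an SM 9.0 GPU
    print(f"compute capability {cap}, covered by -gencode flags: {cap in NVCC_TARGETS}")
else:
    print("no CUDA device visible")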
sgl-kernel/src/sgl-kernel/__init__.py    (view file @ e04d3f28)

-from .ops import warp_reduce
+from .ops import custom_dispose, custom_reduce, init_custom_reduce, warp_reduce

-__all__ = ["warp_reduce"]
+__all__ = [
+    "warp_reduce",
+    "init_custom_reduce",
+    "custom_dispose",
+    "custom_reduce",
+]
sgl-kernel/src/sgl-kernel/csrc/trt_reduce.cc    (new file 100644, view file @ e04d3f28)

#include <torch/extension.h>

using fptr_t = int64_t;

fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out);
void dispose(fptr_t _fa);
void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
  m.def("dispose", &dispose, "dispose custom allreduce meta");
  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
}
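These three bindings are the raw extension surface; the thin Python wrappers added in sgl-kernel/src/sgl-kernel/ops/__init__.py later in this diff call them directly. A minimal, hedged sketch of that surface from Python follows; the wrapper function name is illustrative, and the pointer lists must be CUDA-IPC-shared device pointers as set up in tests/test_trt_reduce.py.

# Sketch only: the raw bindings compiled into sgl_kernel.ops.custom_reduce_cuda by setup.py above.
import torch
from sgl_kernel.ops import custom_reduce_cuda


def run_raw_custom_allreduce(rank_id, world_size, buffers, barrier_in, barrier_out,
                             inp: torch.Tensor, out: torch.Tensor):
    # buffers / barrier_in / barrier_out: lists of int64 device-pointer values, one per rank.
    handle = custom_reduce_cuda.init_custom_ar(rank_id, world_size, buffers, barrier_in, barrier_out)
    try:
        custom_reduce_cuda.all_reduce(handle, inp, out)  # inp/out: CUDA tensors of equal shape
    finally:
        custom_reduce_cuda.dispose(handle)  # free the opaque AllReduceMeta handle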
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu    (new file 100644, view file @ e04d3f28)

// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <tuple>

#include "trt_reduce_internal.cuh"

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) {
  asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) {
  uint32_t flag;
  asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
  return flag;
}

namespace trt_llm {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Type Converter that packs data format to 128 bits data type
//
using PackedFloat = union {
  int4 packed;
  float unpacked[4];
};

using PackedHalf = union {
  int4 packed;
  half2 unpacked[4];
};

template <typename T>
struct PackedOn16Bytes {};

template <>
struct PackedOn16Bytes<float> {
  using Type = PackedFloat;
};

template <>
struct PackedOn16Bytes<half> {
  using Type = PackedHalf;
};

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
using PackedBFloat16 = union {
  int4 packed;
  __nv_bfloat162 unpacked[4];
};

template <>
struct PackedOn16Bytes<__nv_bfloat16> {
  using Type = PackedBFloat16;
};
#endif

// add two 128b data
template <typename T>
inline __device__ int4 add128b(T& a, T& b) {
  T c;
  c.unpacked[0] = a.unpacked[0] + b.unpacked[0];
  c.unpacked[1] = a.unpacked[1] + b.unpacked[1];
  c.unpacked[2] = a.unpacked[2] + b.unpacked[2];
  c.unpacked[3] = a.unpacked[3] + b.unpacked[3];
  return c.packed;
}

__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
                                             size_t const world_size, int const tidx, int const bidx) {
  // After this function, at least one block in each GPU has reached the barrier
  if (tidx < world_size) {
    // we can think of signals having the shape [world_size, world_size]
    // Dimension 0 is the "listening" dimension, dimension 1 is the "emitting" dimension
    // Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
    size_t offset = (flag % 2) ? world_size : 0;

    if (bidx == 0) {
      st_flag_release(flag, signals[tidx] + offset + local_rank);
    }

    // All blocks check that the corresponding block 0 on other GPUs has set the flag
    // No deadlock because block #0 is always the first block started
    uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx;
    while (ld_flag_acquire(peer_barrier_d) != flag) {
    }
  }

  __syncthreads();
}

template <typename T, int RANKS_PER_NODE> /* COPY_INPUT = false, PUSH_MODE = false */
static __global__ void oneShotAllReduceKernel(AllReduceParams params) {
  // Suppose that two GPUs participate in the AR exchange, and we start four blocks.
  // The message is partitioned into chunks as detailed below:
  //               message
  //       |-------------------|
  // GPU 0 | B0 | B1 | B2 | B3 |
  // GPU 1 | B0 | B1 | B2 | B3 |
  //
  // Here is the step-by-step behavior of one block:
  // 1. B0 copies the chunk it is responsible for, from local_input to the shareable buffer
  // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
  // 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, and writes the result to local_output
  //
  // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
  // We only need to know that the other GPU has arrived at the AR kernel; that means the data is ready
  //
  // With PUSH_MODE, we consider that the shared buffer is of size:
  // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
  //
  // Here is the step-by-step behavior of one block:
  // 1. B0 pushes the chunk it is responsible for into all other GPUs:
  //    params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
  // 2. block sync so the block is shared by other GPUs
  // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]

  int const bidx = blockIdx.x;
  int const tidx = threadIdx.x;

  // The number of elements packed into one for comms
  static constexpr int NUM_ELTS = 16 / sizeof(T);

  // Packed data type for comms
  using PackedStruct = typename PackedOn16Bytes<T>::Type;

  // The source pointers. Distributed round-robin for the different warps.
  T const* buffers[RANKS_PER_NODE];

  // Start and end offsets of the thread
  size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS;
  size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
#pragma unroll
  for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
    int rank = (params.local_rank + ii) % RANKS_PER_NODE;
    buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
  }

  multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);

  // Each block accumulates the values from the different GPUs on the same node.
  for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) {
    // Iterate over the different ranks/devices on the node to load the values.
    PackedStruct vals[RANKS_PER_NODE];
#pragma unroll
    for (int ii = 0; ii < RANKS_PER_NODE; ++ii) {
      vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
    }

    // Sum the values from the different ranks.
    PackedStruct sums;
    sums.packed = {0, 0, 0, 0};
#pragma unroll
    for (int rank = 0; rank < RANKS_PER_NODE; ++rank) {
      // Always reduce from rank 0 to ensure stable reduce order.
      int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE;
      sums.packed = add128b(sums, vals[ii]);
    }

    // Store to the destination buffer.
    *reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int divUp(int a, int b) {
  return (a + b - 1) / b;
}

inline int roundUp(int a, int n) {
  return divUp(a, n) * n;
}

std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) {
  int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
  switch (algo) {
    case AllReduceStrategyType::ONESHOT: {
      assert(params.elts_total % elts_per_thread == 0);
      size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE);
      threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
      blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
      params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread);
      params.elts_per_rank = params.elts_total;
      break;
    }
    default:
      assert(false && "Algorithm not supported here.");
  }

  return std::make_tuple(blocks_per_grid, threads_per_block);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, int RANKS_PER_NODE>
void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block,
                       cudaStream_t stream) {
  oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
}

template <typename T>
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) {
  void* buffer = reinterpret_cast<void*>(param.peer_comm_buffer_ptrs[param.rank]);
  void* local_inp_buffer = param.local_input_buffer_ptr;
  CHECK_CUDA_SUCCESS(
      cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream));

  assert(strat == AllReduceStrategyType::ONESHOT && "Custom allreduce only support oneshot");
  CHECK_CUDA_SUCCESS(cudaGetLastError());

  size_t elts_per_thread = 16 / sizeof(T);
  auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
  switch (param.ranks_per_node) {
    case 2:
      dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 4:
      dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 6:
      dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    case 8:
      dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream);
      break;
    default:
      break;
  }
  CHECK_CUDA_SUCCESS(cudaGetLastError());
}

void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
                        cudaStream_t stream) {
  if (params.elts_total == 0) {
    return;
  }

  switch (data_type) {
    case at::ScalarType::Float:
      invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
      break;
    case at::ScalarType::Half:
      invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
      break;
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16:
      invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
      break;
#endif
    default:
      assert(false && "Unsupported data type");
  }
}
}  // namespace trt_llm
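As the kernel comments describe, the one-shot path has every rank first copy its input into its own shared communication buffer (the cudaMemcpyAsync in invokeOneOrTwoShotAllReduceKernel), then, after the multi-GPU barrier, each rank reads all peers' buffers and accumulates them in a fixed order. The following is a hedged, single-process PyTorch sketch of just that arithmetic (no IPC buffers, no barriers); the function name is illustrative.

# Sketch only: emulate the arithmetic of the one-shot all-reduce on CPU tensors.
# The real kernel works on 16-byte packed loads from per-rank CUDA buffers shared via IPC.
import torch


def one_shot_allreduce_reference(inputs):
    """inputs: one tensor per rank; returns the per-rank outputs (all equal)."""
    world_size = len(inputs)
    # Step 1: each rank "publishes" its input into its shared comm buffer.
    peer_comm_buffers = [x.clone() for x in inputs]
    outputs = []
    for local_rank in range(world_size):
        # Steps 2-3: after the barrier, every rank pulls every peer's chunk and sums them,
        # using a fixed accumulation order (as the kernel does for a stable reduce order).
        acc = torch.zeros_like(inputs[local_rank])
        for rank in range(world_size):
            acc = acc + peer_comm_buffers[rank]
        outputs.append(acc)
    return outputs


# Every rank ends up with the same elementwise sum:
outs = one_shot_allreduce_reference([torch.ones(8) * (r + 1) for r in range(4)])
assert all(torch.equal(o, outs[0]) for o in outs)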
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh    (new file 100644, view file @ e04d3f28)

// reference:
// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/all.h>

#include "utils.hpp"

namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 1024;

enum class AllReduceStrategyType : int8_t {
  RING = 0,
  ONESHOT = 1,
  TWOSHOT = 2,
  AUTO = 3,
};

struct AllReduceParams {
  size_t elts_size;
  size_t elts_total;
  size_t elts_per_rank;
  size_t elts_per_block;
  size_t rank_offset;
  size_t ranks_per_node, rank, local_rank;
  uint32_t barrier_flag;
  uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
  uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
  void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
  void* local_input_buffer_ptr;
  void* local_output_buffer_ptr;
};

inline size_t GetMaxRequiredWorkspaceSize(int world_size) {
  if (world_size <= 2) {
    return 16 * 1000 * 1000;
  }
  return 8 * 1000 * 1000;
}

inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) {
  const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size);

  if (message_size > maxWorkspaceSize) {
    assert(false && "Custom allreduce do not ring currently");
    return AllReduceStrategyType::RING;
  }

  if (world_size <= 2) {
    return AllReduceStrategyType::ONESHOT;
  }

  if (world_size <= 4) {
    if (message_size < 1 * 1000 * 1000) {
      return AllReduceStrategyType::ONESHOT;
    }
    assert(false && "Custom allreduce do not twoshot currently");
    return AllReduceStrategyType::TWOSHOT;
  }

  if (message_size < 500 * 1000) {
    return AllReduceStrategyType::ONESHOT;
  }
  assert(false && "Custom allreduce do not twoshot currently");
  return AllReduceStrategyType::TWOSHOT;
}

void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat,
                        cudaStream_t stream);

}  // namespace trt_llm
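In this commit SelectImplementation only ever admits the one-shot strategy; the RING and TWOSHOT branches assert. A hedged Python mirror of the size and world-size thresholds above, useful for reasoning about which message sizes the custom kernel will accept (the function names are illustrative, not part of the library):

# Sketch only: mirrors trt_llm::GetMaxRequiredWorkspaceSize / SelectImplementation above.
def max_required_workspace_size(world_size: int) -> int:
    return 16 * 1000 * 1000 if world_size <= 2 else 8 * 1000 * 1000


def select_implementation(message_size: int, world_size: int) -> str:
    if message_size > max_required_workspace_size(world_size):
        return "RING"      # asserts in the C++ code: ring is not implemented yet
    if world_size <= 2:
        return "ONESHOT"
    if world_size <= 4:
        return "ONESHOT" if message_size < 1 * 1000 * 1000 else "TWOSHOT"  # TWOSHOT asserts
    return "ONESHOT" if message_size < 500 * 1000 else "TWOSHOT"           # TWOSHOT asserts


# Example: a 512 KiB message on 8 ranks exceeds the 500*1000-byte one-shot cutoff.
print(select_implementation(512 * 1024, 8))  # -> "TWOSHOT" (unsupported, asserts in C++)
print(select_implementation(256 * 1024, 8))  # -> "ONESHOT"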
sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu    (new file 100644, view file @ e04d3f28)

// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h

#include <c10/cuda/CUDAStream.h>

#include <cassert>
#include <iostream>
#include <sstream>
#include <unordered_map>

#include "trt_reduce_internal.cuh"

using namespace trt_llm;

using fptr_t = int64_t;

class AllReduceMeta {
 public:
  AllReduceMeta(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
                const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
    this->rank_id = (int)rank_id;
    this->world_size = (int)world_size;
    this->buffers = buffers;
    this->barrier_in = barrier_in;
    this->barrier_out = barrier_out;
  }

 public:
  int world_size;
  int rank_id;
  std::vector<fptr_t> buffers;
  std::vector<fptr_t> barrier_in;
  std::vector<fptr_t> barrier_out;
  int barrier_flag = 1;
};

// Get the number of bits for a given data type.
inline int get_bits(at::ScalarType dtype) {
  switch (dtype) {
    case at::ScalarType::Float:
      return 32;
    case at::ScalarType::Half:
    case at::ScalarType::BFloat16:
      return 16;
    default:
      assert(false && "Unsupported data type");
  }
}

// Check if customized all-reduce kernels can be applied.
inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) {
  // The customized all-reduce kernel has the following requirement(s).
  return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0;
}

fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector<fptr_t>& buffers,
                      const std::vector<fptr_t>& barrier_in, const std::vector<fptr_t>& barrier_out) {
  auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out);
  return (fptr_t)m;
}

void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<AllReduceMeta*>(_fa);
  delete fa;
}

void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) {
  AllReduceMeta* m = reinterpret_cast<AllReduceMeta*>(_fa);
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  auto num_elements = inp.numel();
  auto dtype = inp.scalar_type();
  AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size);

  // should be guaranteed in the Python code
  assert(strategy == AllReduceStrategyType::ONESHOT);
  assert(CanApplyCustomAllReduce(num_elements, dtype));

  // Initialize the all-reduce kernel arguments.
  int world_size = m->world_size;

  AllReduceParams params;
  params.ranks_per_node = world_size;
  params.rank = m->rank_id;
  params.local_rank = m->rank_id;
  params.local_input_buffer_ptr = inp.data_ptr();
  params.local_output_buffer_ptr = out.data_ptr();
  params.elts_total = inp.numel();
  params.elts_size = inp.element_size();
  params.barrier_flag = ++(m->barrier_flag);

  for (int i = 0; i < world_size; ++i) {
    params.peer_comm_buffer_ptrs[i] = reinterpret_cast<void*>(m->buffers[i]);
  }
  for (int i = 0; i < world_size; ++i) {
    params.peer_barrier_ptrs_in[i] = reinterpret_cast<uint32_t*>(m->barrier_in[i]);
  }
  for (int i = 0; i < world_size; ++i) {
    params.peer_barrier_ptrs_out[i] = reinterpret_cast<uint32_t*>(m->barrier_out[i]);
  }

  auto data_type = out.scalar_type();
  trtCustomAllReduce(params, data_type, strategy, stream);
}
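CanApplyCustomAllReduce above requires the element count to be a multiple of the number of elements that fit into one 16-byte packed load: 4 for fp32 and 8 for fp16/bf16. A hedged Python restatement of that check (the function name is illustrative):

# Sketch only: mirrors get_bits / CanApplyCustomAllReduce above.
import torch

_BITS = {torch.float32: 32, torch.float16: 16, torch.bfloat16: 16}


def can_apply_custom_all_reduce(num_elements: int, dtype: torch.dtype) -> bool:
    elts_per_16_bytes = 16 // ((_BITS[dtype] + 7) // 8)  # 4 for fp32, 8 for fp16/bf16
    return num_elements % elts_per_16_bytes == 0


print(can_apply_custom_all_reduce(4096, torch.float16))   # True
print(can_apply_custom_all_reduce(4100, torch.bfloat16))  # False (not a multiple of 8)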
sgl-kernel/src/sgl-kernel/csrc/utils.hpp    (new file 100644, view file @ e04d3f28)

#pragma once

#include <torch/extension.h>

#include <sstream>

struct cuda_error : public std::runtime_error {
  /**
   * @brief Constructs a `cuda_error` object with the given `message`.
   *
   * @param message The error char array used to construct `cuda_error`
   */
  cuda_error(const char* message) : std::runtime_error(message) {}
  /**
   * @brief Constructs a `cuda_error` object with the given `message` string.
   *
   * @param message The `std::string` used to construct `cuda_error`
   */
  cuda_error(std::string const& message) : cuda_error{message.c_str()} {}
};
#define CHECK_CUDA_SUCCESS(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
std::stringstream _message; \
auto s = cudaGetErrorString(e); \
_message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \
throw cuda_error(_message.str()); \
} \
} while (0)
#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_CUDA_INPUT(x) \
CHECK_IS_CUDA(x); \
CHECK_IS_CONTIGUOUS(x)
sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc    (view file @ e04d3f28)

 #include <torch/extension.h>

-torch::Tensor warp_reduce_cuda(torch::Tensor input);
+#include "utils.hpp"

-#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) \
-  CHECK_CUDA(x);       \
-  CHECK_CONTIGUOUS(x)
+torch::Tensor warp_reduce_cuda(torch::Tensor input);

 torch::Tensor warp_reduce(torch::Tensor input) {
-  CHECK_INPUT(input);
+  CHECK_CUDA_INPUT(input);
   return warp_reduce_cuda(input);
 }
...
sgl-kernel/src/sgl-kernel/ops/__init__.py    (view file @ e04d3f28)

from .custom_reduce_cuda import all_reduce as _all_reduce
from .custom_reduce_cuda import dispose as _dispose
from .custom_reduce_cuda import init_custom_ar as _init_custom_ar
from .warp_reduce_cuda import reduce as _reduce


def warp_reduce(input_tensor):
    return _reduce(input_tensor)


def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
    return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out)


def custom_dispose(fa):
    _dispose(fa)


def custom_reduce(fa, inp, out):
    _all_reduce(fa, inp, out)
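These wrappers are exercised end to end in tests/test_trt_reduce.py below. A condensed, hedged sketch of the per-rank flow: the buffer sizes follow the test, the create_shared_buffer/free_shared_buffer helpers (passed in here as parameters) correspond to the CUDA-IPC helpers defined in the test, torch.distributed is assumed to be initialized already, and the wrapper function name is illustrative.

# Condensed sketch of the per-rank flow from tests/test_trt_reduce.py (not a drop-in script).
import torch
import sgl_kernel


def run_custom_allreduce(rank, world_size, create_shared_buffer, free_shared_buffer, group):
    # IPC-shared communication and barrier buffers, one pointer per rank (sizes as in the test).
    buffer_ptrs = create_shared_buffer(8 * 1024 * 1024, group=group)
    barrier_in_ptrs = create_shared_buffer(8 * (24 + 2) * 8, group=group)
    barrier_out_ptrs = create_shared_buffer(8 * (24 + 2) * 8, group=group)

    handle = sgl_kernel.ops.init_custom_reduce(
        rank, world_size, buffer_ptrs, barrier_in_ptrs, barrier_out_ptrs
    )
    try:
        inp = torch.randint(1, 16, (4096,), dtype=torch.float16, device="cuda")
        out = torch.empty_like(inp)
        sgl_kernel.ops.custom_reduce(handle, inp, out)  # out = elementwise sum of inp across ranks
    finally:
        sgl_kernel.ops.custom_dispose(handle)
        for ptrs in (buffer_ptrs, barrier_in_ptrs, barrier_out_ptrs):
            free_shared_buffer(ptrs, group=group)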
sgl-kernel/tests/test_trt_reduce.py    (new file 100644, view file @ e04d3f28)

import ctypes
import logging
import os
import random
import socket
import time
import unittest
from typing import Any, List, Optional, Union

import ray
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from vllm import _custom_ops as vllm_ops

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

logger = logging.getLogger(__name__)


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int,
    cls: Any,
    test_target: Any,
) -> None:
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers
    ray.init(log_to_driver=True)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(test_target.remote(cls, world_size, rank, distributed_init_port))
    ray.get(refs)

    ray.shutdown()


class TestCustomAllReduce(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        random.seed(42)
        cls.test_sizes = {
            2: [512, 4096, 32768, 262144, 2097152],
            4: [512, 4096, 32768, 131072],
            6: [512, 4096, 32768, 65536],
            8: [512, 4096, 32768, 65536],
        }
        cls.world_sizes = [2, 4, 6, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        """
        Creates a shared buffer and returns a list of pointers
        representing the buffer on all processes in the group.
        """
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        dist.all_gather_object(handles, handle, group=group)

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer.value)  # type: ignore
            else:
                pointers.append(lib.cudaIpcOpenMemHandle(h).value)  # type: ignore

        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))

    def test_correctness(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.correctness)

    def test_performance(self):
        for world_size in self.world_sizes:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.performance)

    def init_custom_allreduce(self, rank, world_size, group):
        import sgl_kernel

        buffer_max_size = 8 * 1024 * 1024
        barrier_max_size = 8 * (24 + 2) * 8

        self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group)
        self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group)
        self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group)

        self.custom_ptr = sgl_kernel.ops.init_custom_reduce(
            rank,
            world_size,
            self.buffer_ptrs,
            self.barrier_in_ptrs,
            self.barrier_out_ptrs,
        )

    def custom_allreduce(self, inp, out):
        import sgl_kernel

        sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out)

    def free_custom_allreduce(self, group):
        import sgl_kernel

        self.free_shared_buffer(self.buffer_ptrs, group)
        self.free_shared_buffer(self.barrier_in_ptrs, group)
        self.free_shared_buffer(self.barrier_out_ptrs, group)
        sgl_kernel.ops.custom_dispose(self.custom_ptr)

    def init_vllm_allreduce(self, rank, group):
        self.vllm_rank = rank
        self.vllm_max_size = 8 * 1024 * 1024
        self.vllm_meta_ptrs = self.create_shared_buffer(
            vllm_ops.meta_size() + self.vllm_max_size, group=group
        )
        self.vllm_buffer_ptrs = self.create_shared_buffer(self.vllm_max_size, group=group)
        self.vllm_rank_data = torch.empty(
            8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}")
        )
        self.vllm_ptr = vllm_ops.init_custom_ar(
            self.vllm_meta_ptrs, self.vllm_rank_data, rank, True
        )
        vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs)

    def vllm_allreduce(self, inp, out):
        vllm_ops.all_reduce(
            self.vllm_ptr,
            inp,
            out,
            self.vllm_buffer_ptrs[self.vllm_rank],
            self.vllm_max_size,
        )

    def free_vllm_allreduce(self, group):
        vllm_ops.dispose(self.vllm_ptr)
        self.free_shared_buffer(self.vllm_meta_ptrs, group)
        self.free_shared_buffer(self.vllm_buffer_ptrs, group)

    @staticmethod
    def init_distributed_env(world_size, rank, distributed_init_port):
        del os.environ["CUDA_VISIBLE_DEVICES"]
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        ranks = [i for i in range(world_size)]
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        dist.init_process_group(
            backend="nccl",
            init_method=distributed_init_method,
            rank=rank,
            world_size=world_size,
        )
        group = torch.distributed.new_group(ranks, backend="gloo")
        return group

    # compare result with torch.distributed
    @ray.remote(num_gpus=1, max_calls=1)
    def correctness(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)

        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)

        test_loop = 10
        for sz in self.test_sizes[world_size]:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(test_loop):
                    inp1 = torch.randint(
                        1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device()
                    )
                    out1 = torch.empty_like(inp1)
                    self.custom_allreduce(inp1, out1)

                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1)

        self.free_custom_allreduce(group)

    # compare performance with vllm
    @ray.remote(num_gpus=1, max_calls=1)
    def performance(self, world_size, rank, distributed_init_port):
        group = self.init_distributed_env(world_size, rank, distributed_init_port)

        self.init_vllm_allreduce(rank, group)
        self.init_custom_allreduce(rank=rank, world_size=world_size, group=group)

        for sz in self.test_sizes[world_size]:
            inp1 = torch.randint(
                1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device()
            )
            out1 = torch.empty_like(inp1)

            test_loop = 5000
            start = time.time()
            for _ in range(test_loop):
                self.custom_allreduce(inp1, out1)
            elapse_custom = time.time() - start

            start = time.time()
            for _ in range(test_loop):
                self.vllm_allreduce(inp1, out1)
            elapse_vllm = time.time() - start

            if rank == 0:
                # elapse is in seconds, so elapse * 1000 / test_loop is milliseconds per call.
                logger.warning(
                    f"test_size = {sz}, world_size = {world_size}, "
                    f"vllm time = {elapse_vllm * 1000 / test_loop:.4f} ms, "
                    f"custom time = {elapse_custom * 1000 / test_loop:.4f} ms"
                )

        self.free_custom_allreduce(group)
        self.free_vllm_allreduce(group)


if __name__ == "__main__":
    unittest.main()