change / sglang · Commits

Commit 9602c2aa (Unverified)
Authored Jan 31, 2025 by Yineng Zhang, committed by GitHub on Jan 31, 2025

keep the parts needed for moe_kernels (#3218)
Parent: e81d7f11
Changes: 24
Showing 20 changed files with 1845 additions and 2823 deletions (+1845 -2823)
sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt           +0    -22
sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp               +0    -0
sgl-kernel/3rdparty/tensorrt_llm/common/assert.h                 +92   -0
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp    +187  -0
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h      +138  -0
sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h           +239  -0
sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp    +0    -84
sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h              +641  -0
sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h   +0    -36
sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp             +0    -214
sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h               +0    -60
sgl-kernel/3rdparty/tensorrt_llm/common/logger.h                 +190  -0
sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h              +0    -37
sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu           +0    -906
sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h            +0    -292
sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp             +0    -588
sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h              +0    -46
sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp              +0    -323
sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h                +0    -215
sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h           +358  -0
sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt (deleted, 100644 → 0)
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
file(GLOB CU_SRCS *.cu)

add_library(common_src OBJECT ${SRCS} ${CU_SRCS})
set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp (file mode changed from 100755 to 100644)
sgl-kernel/3rdparty/tensorrt_llm/common/assert.h (new file, 0 → 100644)
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/common/tllmException.h"
#include <string>
namespace tensorrt_llm::common
{

[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
{
    throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()));
}

} // namespace tensorrt_llm::common

class DebugConfig
{
public:
    static bool isCheckDebugEnabled();
};
#if defined(_WIN32)
#define TLLM_LIKELY(x) (__assume((x) == 1), (x))
#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x))
#else
#define TLLM_LIKELY(x) __builtin_expect((x), 1)
#define TLLM_UNLIKELY(x) __builtin_expect((x), 0)
#endif
#define TLLM_CHECK(val) \
do \
{ \
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
} while (0)
#define TLLM_CHECK_WITH_INFO(val, info, ...) \
do \
{ \
TLLM_LIKELY(static_cast<bool>(val)) \
? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError( \
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
} while (0)
#define TLLM_CHECK_DEBUG(val) \
do \
{ \
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
{ \
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
} \
} while (0)
#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \
do \
{ \
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
{ \
TLLM_LIKELY(static_cast<bool>(val)) \
? ((void) 0) \
: tensorrt_llm::common::throwRuntimeError( \
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
} \
} while (0)
#define TLLM_THROW(...) \
do \
{ \
throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \
} while (0)
#define TLLM_WRAP(ex) \
NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what())
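Note: as a quick illustration (not part of this commit; the function and its arguments below are hypothetical), the check macros above turn a failed condition into a tensorrt_llm::common::TllmException that carries the file, line, and a printf-style message:

// Hypothetical usage sketch for TLLM_CHECK / TLLM_CHECK_WITH_INFO.
#include "tensorrt_llm/common/assert.h"

void resizeBuffer(int newSize, int capacity)
{
    // On failure, throwRuntimeError() is invoked with __FILE__/__LINE__ and the formatted info string.
    TLLM_CHECK(newSize >= 0);
    TLLM_CHECK_WITH_INFO(newSize <= capacity, "requested size %d exceeds capacity %d", newSize, capacity);
}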
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp (new file, 0 → 100644)
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#define CUDA_LIB_NAME "cuda"
#if defined(_WIN32)
#include <windows.h>
#define dllOpen(name) LoadLibrary("nv" name ".dll")
#define dllClose(handle) FreeLibrary(static_cast<HMODULE>(handle))
#define dllGetSym(handle, name) static_cast<void*>(GetProcAddress(static_cast<HMODULE>(handle), name))
#else // For non-Windows platforms
#include <dlfcn.h>
#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY)
#define dllClose(handle) dlclose(handle)
#define dllGetSym(handle, name) dlsym(handle, name)
#endif // defined(_WIN32)
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
namespace tensorrt_llm::common
{

std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
{
    static std::mutex mutex;
    static std::weak_ptr<CUDADriverWrapper> instance;
    std::shared_ptr<CUDADriverWrapper> result = instance.lock();
    if (result)
    {
        return result;
    }

    std::lock_guard<std::mutex> lock(mutex);
    result = instance.lock();
    if (!result)
    {
        result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
        instance = result;
    }
    return result;
}

CUDADriverWrapper::CUDADriverWrapper()
    : handle(dllOpen(CUDA_LIB_NAME))
{
    TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");

    auto load_sym = [](void* handle, char const* name)
    {
        void* ret = dllGetSym(handle, name);
        return ret;
    };

    *reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
    *reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
    *reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
    *reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
    *reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
    *reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
    *reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
    *reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
    *reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
    *reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
    *reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
    *reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
    *reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
    *reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
    *reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
    *reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
}

CUDADriverWrapper::~CUDADriverWrapper()
{
    dllClose(handle);
}

CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
{
    return (*_cuGetErrorName)(error, pStr);
}

CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
{
    return (*_cuGetErrorMessage)(error, pStr);
}

CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
    return (*_cuFuncSetAttribute)(hfunc, attrib, value);
}

CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const
{
    return (*_cuLinkComplete)(state, cubinOut, sizeOut);
}

CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const
{
    return (*_cuModuleUnload)(hmod);
}

CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
{
    return (*_cuLinkDestroy)(state);
}

CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
{
    return (*_cuModuleLoadData)(module, image);
}

CUresult CUDADriverWrapper::cuLinkCreate(
    unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const
{
    return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
}

CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
{
    return (*_cuModuleGetFunction)(hfunc, hmod, name);
}

CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
{
    return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
}

CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
    unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
    return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
}

CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
    char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{
    return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
}

CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
    unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
    unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const
{
    return (*_cuLaunchCooperativeKernel)(
        f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams);
}

CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
    unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
    unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const
{
    return (*_cuLaunchKernel)(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes,
        hStream, kernelParams, extra);
}

CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
    cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const
{
    return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides,
        boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill);
}

CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const
{
    return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
}

} // namespace tensorrt_llm::common
sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h (new file, 0 → 100644)
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <memory>
#include <mutex>
namespace tensorrt_llm::common
{

class CUDADriverWrapper
{
public:
    static std::shared_ptr<CUDADriverWrapper> getInstance();

    ~CUDADriverWrapper();
    CUDADriverWrapper(CUDADriverWrapper const&) = delete;
    CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
    CUDADriverWrapper(CUDADriverWrapper&&) = delete;
    CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;

    CUresult cuGetErrorName(CUresult error, char const** pStr) const;
    CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
    CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
    CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
    CUresult cuModuleUnload(CUmodule hmod) const;
    CUresult cuLinkDestroy(CUlinkState state) const;
    CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
    CUresult cuLinkCreate(
        unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
    CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
    CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
    CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
        CUjit_option* options, void** optionValues) const;
    CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
        unsigned int numOptions, CUjit_option* options, void** optionValues) const;
    CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
        unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ,
        unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const;
    CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
        unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
        CUstream hStream, void** kernelParams, void** extra) const;
    CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank,
        void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim,
        cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle,
        CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const;
    CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const;

private:
    void* handle;
    CUDADriverWrapper();

    CUresult (*_cuGetErrorName)(CUresult, char const**);
    CUresult (*_cuGetErrorMessage)(CUresult, char const**);
    CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
    CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
    CUresult (*_cuModuleUnload)(CUmodule);
    CUresult (*_cuLinkDestroy)(CUlinkState);
    CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
    CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
    CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
    CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
    CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
    CUresult (*_cuLinkAddData)(
        CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
    CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
        unsigned int, unsigned int, unsigned int, CUstream, void**);
    CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
        unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes,
        CUstream hStream, void** kernelParams, void** extra);
    CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType,
        cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides,
        cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave,
        CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill);
    CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};

template <typename T>
void checkDriver(
    T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
    if (result)
    {
        char const* errorName = nullptr;
        char const* errorMsg = nullptr;
        wrap.cuGetErrorName(result, &errorName);
        wrap.cuGetErrorMessage(result, &errorMsg);
        throw TllmException(
            file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
    }
}

} // namespace tensorrt_llm::common
/*
* Macros compliant with TensorRT coding conventions
*/
#define TLLM_CU_CHECK(stat) \
do \
{ \
tensorrt_llm::common::checkDriver( \
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
} while (0)
#endif // CUDA_DRIVER_WRAPPER_H
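Note: a minimal sketch of how the wrapper and TLLM_CU_CHECK fit together (not part of this commit; the cubin image and the kernel name "my_kernel" are hypothetical):

// Hypothetical usage sketch: drive the lazily-dlopen'd CUDA driver API through CUDADriverWrapper.
#include "tensorrt_llm/common/cudaDriverWrapper.h"

void loadAndReleaseModule(void const* cubinImage)
{
    auto driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
    CUmodule module;
    CUfunction function;
    // TLLM_CU_CHECK throws a TllmException with cuGetErrorName/cuGetErrorMessage details on a non-zero CUresult.
    TLLM_CU_CHECK(driver->cuModuleLoadData(&module, cubinImage));
    TLLM_CU_CHECK(driver->cuModuleGetFunction(&function, module, "my_kernel"));
    TLLM_CU_CHECK(driver->cuModuleUnload(module));
}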
sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h (new file, 0 → 100644)
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_FP8
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <stdint.h>
#define FP8_MHA
#define FUSE_GEMM_ACT
#define FP8_GEMM_OUTPUT_QUANT_DISABLE
#ifdef FUSE_GEMM_ACT
#define USE_QGMMA
#endif
namespace tensorrt_llm
{
namespace common
{

constexpr float FP8_E4M3_MAX = 448.0f;

enum QuantizeMode
{
    PER_CHANNEL,
    PER_TENSOR,
    PER_CHANNEL_WEIGHT_PER_TENSOR_ACT,
    PER_TOKEN,
};

// Packed Data Type
typedef struct __CUDA_ALIGN__(32)
{
    float array[8];
} float8;

typedef struct __CUDA_ALIGN__(16)
{
    half array[8];
} half8;

typedef struct __CUDA_ALIGN__(8)
{
    half2 array[2];
} half2_2;

typedef struct __CUDA_ALIGN__(8)
{
    half array[4];
} half_4;

#ifdef ENABLE_BF16
typedef struct __CUDA_ALIGN__(4)
{
    __nv_bfloat16 array[2];
} __nv_bfloat16_2;

typedef struct __CUDA_ALIGN__(8)
{
    __nv_bfloat162 x, y;
} __nv_bfloat162_2_xy;

typedef struct __CUDA_ALIGN__(8)
{
    __nv_bfloat16 array[4];
} __nv_bfloat164;

typedef struct __CUDA_ALIGN__(8)
{
    __nv_bfloat162 array[2];
} __nv_bfloat162_2;

typedef struct __CUDA_ALIGN__(16)
{
    __nv_bfloat16 array[8];
} __nv_bfloat168;

typedef struct __CUDA_ALIGN__(16)
{
    __nv_bfloat162 array[4];
} __nv_bfloat162_4;

typedef struct __CUDA_ALIGN__(32)
{
    __nv_bfloat16 array[16];
} __nv_bfloat1616;
#endif

#ifdef ENABLE_FP8
typedef struct __CUDA_ALIGN__(2)
{
    __nv_fp8_e4m3 array[2];
} __nv_fp8_2_e4m3;

typedef struct __CUDA_ALIGN__(4)
{
    __nv_fp8_e4m3 array[4];
} __nv_fp8_4_e4m3;

typedef struct __CUDA_ALIGN__(4)
{
    __nv_fp8x2_e4m3 array[2];
} __nv_fp8x2_x2_e4m3;

typedef struct __CUDA_ALIGN__(8)
{
    __nv_fp8_e4m3 array[8];
} __nv_fp8_8_e4m3;

typedef struct __CUDA_ALIGN__(8)
{
    __nv_fp8x2_e4m3 array[4];
} __nv_fp8x2_x4_e4m3;

typedef struct __CUDA_ALIGN__(16)
{
    __nv_fp8_e4m3 array[16];
} __nv_fp8x16_e4m3;
#endif

// only BF16 and FP8
template <typename T, int PACK_SIZE>
struct PackType
{
    using type = float;
};

#ifdef ENABLE_BF16
template <>
struct PackType<__nv_bfloat16, 2>
{
    using type = __nv_bfloat16_2;
};

template <>
struct PackType<__nv_bfloat16, 4>
{
    using type = __nv_bfloat164;
};

template <>
struct PackType<__nv_bfloat16, 8>
{
    using type = __nv_bfloat168;
};
#endif

#ifdef ENABLE_FP8
template <>
struct PackType<__nv_fp8_e4m3, 2>
{
    using type = __nv_fp8_2_e4m3;
};

template <>
struct PackType<__nv_fp8_e4m3, 4>
{
    using type = __nv_fp8_4_e4m3;
};

template <>
struct PackType<__nv_fp8_e4m3, 8>
{
    using type = __nv_fp8_8_e4m3;
};
#endif

__inline__ __device__ void fp8x4_e4m3_to_bfloat2(
    __nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in)
{
    const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
    *out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
        (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
    *out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
        (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
}

__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in)
{
    const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
    __nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
        (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
    return out;
}

__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in)
{
    const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
    *out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
        (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
    *out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
        (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
}

__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in)
{
    const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
    half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
        (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
    return out;
}

template <typename T_OUT, typename T_S, typename T_IN>
void invokeQuantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda,
    QuantizeMode quantize_mode, cudaStream_t stream);

template <typename T_OUT, typename T_S, typename T_IN>
void invokeDequantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel,
    int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);

template <typename T_FAKE, typename T_OUT, typename T_IN>
void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream);

template <typename T_S, typename T_W>
void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t k, const int64_t lda,
    QuantizeMode quantize_mode, cudaStream_t stream);

template <typename T_OUT, typename T_S, typename T_IN>
void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* weights, const int64_t numel,
    const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream);

} // namespace common
} // namespace tensorrt_llm

#endif // ENABLE_FP8
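Note: for context only (not part of this commit), a per-tensor FP8 scale is conventionally derived from the tensor's absolute maximum and the FP8_E4M3_MAX constant declared above; a minimal host-side sketch, independent of the invoke* kernel declarations:

// Hypothetical host-side sketch: derive a per-tensor quantization scale from an amax value.
#include <algorithm>

constexpr float kFp8E4m3Max = 448.0f; // mirrors FP8_E4M3_MAX above

inline float computePerTensorScale(float amax)
{
    // Clamp to a tiny positive value so an all-zero tensor does not produce a zero scale.
    amax = std::max(amax, 1e-10f);
    // Dividing values by this scale maps amax onto the largest representable e4m3 value.
    return amax / kFp8E4m3Max;
}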
sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp (deleted, 100644 → 0)
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaProfilerUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdint>
#include <optional>
namespace
{

std::tuple<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexesImpl(
    std::string const& envVarName)
{
    auto envVarVal = std::getenv(envVarName.c_str());
    auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""};
    auto values = tensorrt_llm::common::str2set(envVarValStr, ',');
    std::unordered_set<int32_t> startSet;
    std::unordered_set<int32_t> endSet;
    for (std::string const& value : values)
    {
        size_t dashIdx = value.find("-");
        if (dashIdx != std::string::npos)
        {
            int32_t start = std::stoi(value.substr(0, dashIdx));
            startSet.insert(start);
            int32_t end = std::stoi(value.substr(dashIdx + 1));
            endSet.insert(end);
        }
        else
        {
            int32_t start_end = std::stoi(value);
            startSet.insert(start_end);
            endSet.insert(start_end);
        }
    }

    return std::make_pair(startSet, endSet);
}

} // namespace

namespace tensorrt_llm::common
{

std::pair<std::unordered_set<int32_t>, std::unordered_set<int32_t>> populateIterationIndexes(
    std::string const& envVarName, std::optional<std::string> const& legacyEnvVarName)
{
    auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName);

    // If empty, try to use legacy env var name
    if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty())
    {
        std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value());

        if (!profileIterIdxs.empty() || !stopIterIdxs.empty())
        {
            TLLM_LOG_WARNING(
                "Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. "
                "Please use %s instead.",
                legacyEnvVarName.value().c_str(), envVarName.c_str());
        }
    }

    return std::make_pair(profileIterIdxs, stopIterIdxs);
}

} // namespace tensorrt_llm::common
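Note: the deleted parser above accepts a comma-separated list where each entry is either a single iteration or a start-end range; below is a simplified standalone re-implementation for illustration only (not the code removed by this commit):

// Hypothetical sketch: "10,20-30" -> start iterations {10, 20}, stop iterations {10, 30}.
#include <sstream>
#include <string>
#include <unordered_set>

void parseIterations(std::string const& spec, std::unordered_set<int>& starts, std::unordered_set<int>& stops)
{
    std::stringstream ss(spec);
    std::string item;
    while (std::getline(ss, item, ','))
    {
        auto const dash = item.find('-');
        if (dash != std::string::npos)
        {
            starts.insert(std::stoi(item.substr(0, dash)));
            stops.insert(std::stoi(item.substr(dash + 1)));
        }
        else
        {
            int const v = std::stoi(item);
            starts.insert(v);
            stops.insert(v);
        }
    }
}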
sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h (new file, 0 → 100644; diff collapsed, +641 -0)
sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h (deleted, 100644 → 0)
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
namespace tensorrt_llm::utils::customAllReduceUtils
{

constexpr size_t NUM_POINTERS_PER_RANK = 7;

// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
{
    if (worldSize <= 2)
    {
        return 16 * 1000 * 1000;
    }
    return 8 * 1000 * 1000;
}

} // namespace tensorrt_llm::utils::customAllReduceUtils
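Note: usage of the deleted helper was a single call at workspace-allocation time; a minimal sketch with a hypothetical caller:

// Hypothetical usage sketch: size a custom all-reduce workspace for the current world size.
#include "tensorrt_llm/common/customAllReduceUtils.h"
#include <cstddef>

size_t allReduceWorkspaceBytes(int worldSize)
{
    // 16 MB for worldSize <= 2, 8 MB otherwise, per getMaxRequiredWorkspaceSize above.
    return tensorrt_llm::utils::customAllReduceUtils::getMaxRequiredWorkspaceSize(worldSize);
}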
sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp (deleted, 100644 → 0)
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "envUtils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include <cstdlib>
namespace tensorrt_llm::common
{

std::optional<int32_t> getIntEnv(char const* name)
{
    char const* const env = std::getenv(name);
    if (env == nullptr)
    {
        return std::nullopt;
    }
    int32_t const val = std::stoi(env);
    if (val <= 0)
    {
        return std::nullopt;
    }
    return {val};
};

// Returns true if the env variable exists and is set to "1"
static bool getBoolEnv(char const* name)
{
    char const* env = std::getenv(name);
    return env && env[0] == '1' && env[1] == '\0';
}

// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels()
{
    static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0);
    return forceXQA;
}

std::optional<bool> getEnvEnableXQAJIT()
{
    static bool init = false;
    static bool exists = false;
    static bool enableXQAJIT = false;
    if (!init)
    {
        init = true;
        char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT");
        if (enable_xqa_jit_var)
        {
            exists = true;
            if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0')
            {
                enableXQAJIT = true;
            }
        }
    }
    if (exists)
    {
        return enableXQAJIT;
    }
    else
    {
        return std::nullopt;
    }
}

// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug()
{
    static bool init = false;
    static bool forceMmhaMaxSeqLenTile = false;
    if (!init)
    {
        init = true;
        char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
        if (enable_mmha_debug_var)
        {
            if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')
            {
                forceMmhaMaxSeqLenTile = true;
            }
        }
    }
    return forceMmhaMaxSeqLenTile;
}

int getEnvMmhaBlocksPerSequence()
{
    static bool init = false;
    static int mmhaBlocksPerSequence = 0;
    if (!init)
    {
        init = true;
        char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
        if (mmhaBlocksPerSequenceEnv)
        {
            mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
            if (mmhaBlocksPerSequence <= 0)
            {
                TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!");
            }
        }
    }
    return mmhaBlocksPerSequence;
}

int getEnvMmhaKernelBlockSize()
{
    static bool init = false;
    static int mmhaKernelBlockSize = 0;
    if (!init)
    {
        init = true;
        char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE");
        if (mmhaKernelBlockSizeEnv)
        {
            mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv);
            if (mmhaKernelBlockSize <= 0)
            {
                TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!");
            }
        }
    }
    return mmhaKernelBlockSize;
}

bool getEnvEnablePDL()
{
    static bool init = false;
    static bool enablePDL = false;
    if (!init)
    {
        init = true;
        // PDL only available when arch >= 90
        if (getSMVersion() >= 90)
        {
            // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
            enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
        }
    }
    return enablePDL;
}

bool getEnvUseUCXKvCache()
{
    static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE");
    return useUCXKVCache;
}

std::string getEnvUCXInterface()
{
    static bool init = false;
    static std::string ucxInterface;
    if (!init)
    {
        init = true;
        {
            char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE");
            if (ucx_interface)
            {
                ucxInterface = ucx_interface;
            }
        }
    }
    return ucxInterface;
}

bool getEnvDisaggLayerwise()
{
    static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
    return disaggLayerwise;
}

bool getEnvParallelCacheSend()
{
    static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND");
    return parallelCacheSend;
}

bool getEnvRequestKVCacheSerial()
{
    static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL");
    return requestKVCacheSerial;
}

bool getEnvDisableKVCacheTransferOverlap()
{
    static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP");
    return disableKVCacheTransferOverlap;
}

bool getEnvDisableReceiveKVCacheParallel()
{
    static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL");
    return disableReceiveParallel;
}

} // namespace tensorrt_llm::common
sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h (deleted, 100644 → 0)
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <optional>
#include <string>
namespace tensorrt_llm::common
{

// Useful when you want to inject some debug code controllable with env var.
std::optional<int32_t> getIntEnv(char const* name);

// XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels();

// Whether XQA JIT is enabled.
//
// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned.
std::optional<bool> getEnvEnableXQAJIT();

// Tune the number of blocks per sequence for accuracy/performance purpose.
bool getEnvMmhaMultiblockDebug();

int getEnvMmhaBlocksPerSequence();

int getEnvMmhaKernelBlockSize();

// Whether PDL is enabled.
bool getEnvEnablePDL();

bool getEnvUseUCXKvCache();

std::string getEnvUCXInterface();

bool getEnvDisaggLayerwise();

bool getEnvParallelCacheSend();

bool getEnvRequestKVCacheSerial();

bool getEnvDisableKVCacheTransferOverlap();

bool getEnvDisableReceiveKVCacheParallel();

} // namespace tensorrt_llm::common
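Note: a short sketch of how these env-var accessors were typically consumed (the call site is hypothetical, not part of this commit):

// Hypothetical usage sketch: environment-driven tuning knob with a caller-supplied fallback.
#include "tensorrt_llm/common/envUtils.h"

int pickBlocksPerSequence(int defaultBlocks)
{
    // TRTLLM_MMHA_BLOCKS_PER_SEQUENCE unset or invalid yields 0, so fall back to the default.
    int const fromEnv = tensorrt_llm::common::getEnvMmhaBlocksPerSequence();
    return fromEnv > 0 ? fromEnv : defaultBlocks;
}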
sgl-kernel/3rdparty/tensorrt_llm/common/logger.h (new file, 0 → 100644)
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/stringUtils.h"
namespace tensorrt_llm::common
{

class Logger
{
// On Windows, the file wingdi.h is included which has
// #define ERROR 0
// This breaks everywhere ERROR is used in the Level enum
#ifdef _WIN32
#undef ERROR
#endif // _WIN32

public:
    enum Level
    {
        TRACE = 0,
        DEBUG = 10,
        INFO = 20,
        WARNING = 30,
        ERROR = 40
    };

    static Logger* getLogger();

    Logger(Logger const&) = delete;
    void operator=(Logger const&) = delete;

#if defined(_MSC_VER)
    template <typename... Args>
    void log(Level level, char const* format, Args const&... args);

    template <typename... Args>
    void log(Level level, int rank, char const* format, Args const&... args);
#else
    template <typename... Args>
    void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));

    template <typename... Args>
    void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
#endif

    template <typename... Args>
    void log(Level level, std::string const& format, Args const&... args)
    {
        return log(level, format.c_str(), args...);
    }

    template <typename... Args>
    void log(Level const level, int const rank, std::string const& format, Args const&... args)
    {
        return log(level, rank, format.c_str(), args...);
    }

    void log(std::exception const& ex, Level level = Level::ERROR);

    Level getLevel() const
    {
        return level_;
    }

    void setLevel(Level const level)
    {
        level_ = level;
        log(INFO, "Set logger level to %s", getLevelName(level));
    }

    bool isEnabled(Level const level) const
    {
        return level_ <= level;
    }

private:
    static auto constexpr kPREFIX = "[TensorRT-LLM]";

#ifndef NDEBUG
    Level const DEFAULT_LOG_LEVEL = DEBUG;
#else
    Level const DEFAULT_LOG_LEVEL = INFO;
#endif

    Level level_ = DEFAULT_LOG_LEVEL;

    Logger(); // NOLINT(modernize-use-equals-delete)

    static inline char const* getLevelName(Level const level)
    {
        switch (level)
        {
        case TRACE: return "TRACE";
        case DEBUG: return "DEBUG";
        case INFO: return "INFO";
        case WARNING: return "WARNING";
        case ERROR: return "ERROR";
        }

        TLLM_THROW("Unknown log level: %d", level);
    }

    static inline std::string getPrefix(Level const level)
    {
        return fmtstr("%s[%s] ", kPREFIX, getLevelName(level));
    }

    static inline std::string getPrefix(Level const level, int const rank)
    {
        return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank);
    }
};

template <typename... Args>
void Logger::log(Logger::Level level, char const* format, Args const&... args)
{
    if (isEnabled(level))
    {
        auto const fmt = getPrefix(level) + format;
        auto& out = level_ < WARNING ? std::cout : std::cerr;
        if constexpr (sizeof...(args) > 0)
        {
            out << fmtstr(fmt.c_str(), args...);
        }
        else
        {
            out << fmt;
        }
        out << std::endl;
    }
}

template <typename... Args>
void Logger::log(Logger::Level const level, int const rank, char const* format, Args const&... args)
{
    if (isEnabled(level))
    {
        auto const fmt = getPrefix(level, rank) + format;
        auto& out = level_ < WARNING ? std::cout : std::cerr;
        if constexpr (sizeof...(args) > 0)
        {
            out << fmtstr(fmt.c_str(), args...);
        }
        else
        {
            out << fmt;
        }
        out << std::endl;
    }
}
#define TLLM_LOG(level, ...) \
do \
{ \
auto* const logger = tensorrt_llm::common::Logger::getLogger(); \
if (logger->isEnabled(level)) \
{ \
logger->log(level, __VA_ARGS__); \
} \
} while (0)
#define TLLM_LOG_TRACE(...) TLLM_LOG(tensorrt_llm::common::Logger::TRACE, __VA_ARGS__)
#define TLLM_LOG_DEBUG(...) TLLM_LOG(tensorrt_llm::common::Logger::DEBUG, __VA_ARGS__)
#define TLLM_LOG_INFO(...) TLLM_LOG(tensorrt_llm::common::Logger::INFO, __VA_ARGS__)
#define TLLM_LOG_WARNING(...) TLLM_LOG(tensorrt_llm::common::Logger::WARNING, __VA_ARGS__)
#define TLLM_LOG_ERROR(...) TLLM_LOG(tensorrt_llm::common::Logger::ERROR, __VA_ARGS__)
#define TLLM_LOG_EXCEPTION(ex, ...) tensorrt_llm::common::Logger::getLogger()->log(ex, ##__VA_ARGS__)
}
// namespace tensorrt_llm::common
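Note: a minimal sketch of the logging macros in use (illustrative only, not part of this commit):

// Hypothetical usage sketch: TLLM_LOG_* checks the level before formatting, per the TLLM_LOG macro above.
#include "tensorrt_llm/common/logger.h"

void reportShape(int rows, int cols)
{
    auto* logger = tensorrt_llm::common::Logger::getLogger();
    logger->setLevel(tensorrt_llm::common::Logger::INFO);
    TLLM_LOG_INFO("matrix shape: %d x %d", rows, cols);
    TLLM_LOG_DEBUG("skipped unless the level is DEBUG or lower");
}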
sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h (deleted, 100644 → 0)
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime.h>
namespace tensorrt_llm
{
namespace common
{

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
inline __device__ __host__ T divUp(T m, T n)
{
    return (m + n - 1) / n;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace common
} // namespace tensorrt_llm
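Note: divUp is the usual ceiling division used when sizing launch grids; a small sketch (the block size and caller below are hypothetical):

// Hypothetical usage sketch: ceiling division when computing a 1-D launch grid.
#include "tensorrt_llm/common/mathUtils.h" // the header shown above (removed by this commit)
#include <cuda_runtime.h>

inline dim3 makeGrid(int numElements, int blockSize = 256)
{
    // divUp(numElements, blockSize) == (numElements + blockSize - 1) / blockSize
    int const blocks = tensorrt_llm::common::divUp(numElements, blockSize);
    return dim3(blocks, 1, 1);
}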
sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu (deleted, 100644 → 0; diff collapsed, -906)
sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h (deleted, 100644 → 0)
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include <cassert>
namespace tensorrt_llm
{
namespace common
{

template <typename T>
void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true);

template <typename T>
void deviceMemSetZero(T* ptr, size_t size);

template <typename T>
void deviceFree(T*& ptr);

template <typename T>
void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0);

template <typename T>
void cudaD2Hcpy(T* tgt, T const* src, size_t const size);

template <typename T>
void cudaH2Dcpy(T* tgt, T const* src, size_t const size);

template <typename T>
void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);

template <typename T>
void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL);

template <typename T>
void cudaRandomUniform(T* buffer, size_t const size);

template <typename T>
int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename,
    TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);

// template<typename T>
// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,
//                                               T* scale_ptr,
//                                               std::vector<size_t> shape,
//                                               std::string filename,
//                                               TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32);

void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream);
#ifdef ENABLE_FP8
void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream);
void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
#endif // ENABLE_FP8
#ifdef ENABLE_BF16
void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream);
#endif // ENABLE_BF16

template <typename T_OUT, typename T_IN>
void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream);

////////////////////////////////////////////////////////////////////////////////////////////////////
// The following functions implement conversion of multi-dimensional indices to an index in a flat array.
// The shape of the Tensor dimensions is passed as one array (`dims`), the indices are given as individual arguments.
// For examples on how to use these functions, see their tests `test_memory_utils.cu`.
// All of these functions can be evaluated at compile time by recursive template expansion.

template <typename TDim, typename T, typename TIndex>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
    T const& acc, TDim dims, TIndex const& index)
{
    assert(index < dims[0]);
    return acc * dims[0] + index;
}

template <typename TDim, typename T, typename TIndex, typename... TIndices>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
    T const& acc, TDim dims, TIndex const& index, TIndices... indices)
{
    assert(index < dims[0]);
    return flat_index(acc * dims[0] + index, dims + 1, indices...);
}

template <typename TDim, typename T>
__inline__ __host__ __device__ std::enable_if_t<std::is_pointer<TDim>::value, T> constexpr flat_index(
    [[maybe_unused]] TDim dims, T const& index)
{
    assert(index < dims[0]);
    return index;
}

template <typename TDim, typename TIndex, typename... TIndices>
__inline__ __host__ __device__
    std::enable_if_t<std::is_pointer<TDim>::value, typename std::remove_pointer<TDim>::type> constexpr flat_index(
        TDim dims, TIndex const& index, TIndices... indices)
{
    assert(index < dims[0]);
    return flat_index(static_cast<typename std::remove_pointer<TDim>::type>(index), dims + 1, indices...);
}

template <unsigned skip = 0, typename T, std::size_t N, typename TIndex, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
    std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
{
    static_assert(skip < N);
    static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
    return flat_index(&dims[skip], index, indices...);
}

template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
    T const& acc, std::array<T, N> const& dims, TIndex const& index, TIndices... indices)
{
    static_assert(skip < N);
    static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
    return flat_index(acc, &dims[skip], index, indices...);
}

template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices)
{
    static_assert(skip < N);
    static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
    return flat_index(static_cast<T const*>(dims) + skip, index, indices...);
}

template <unsigned skip = 0, typename T, typename TIndex, std::size_t N, typename... TIndices>
__inline__ __host__ __device__ T constexpr flat_index(
    T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices)
{
    static_assert(skip < N);
    static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions");
    return flat_index(acc, static_cast<T const*>(dims) + skip, index, indices...);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual
// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions
// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be
// evaluated at compile time.

template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1)
{
    assert(index_1 < dim_1);
    return index_0 * dim_1 + index_1;
}

template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index3(
    TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2)
{
    assert(index_2 < dim_2);
    return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2;
}

template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1,
    TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3)
{
    assert(index_3 < dim_3);
    return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3;
}

template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1,
    TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2,
    T const& dim_3, T const& dim_4)
{
    assert(index_4 < dim_4);
    return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4;
}

template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index_strided3(
    TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2)
{
    assert(index_1 < stride_1 / stride_2);
    assert(index_2 < stride_2);
    return index_0 * stride_1 + index_1 * stride_2 + index_2;
}

template <typename T, typename TIndex>
__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1,
    TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3)
{
    assert(index_1 < stride_1 / stride_2);
    assert(index_2 < stride_2 / stride_3);
    assert(index_3 < stride_3);
    return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T>
void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1);

template <typename T>
void invokeInPlaceTranspose0213(
    T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3);

template <typename T>
void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2);

template <typename T>
void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream);

template <typename T>
void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream);

template <typename T_IN, typename T_OUT>
void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0);

template <typename T_IN, typename T_OUT>
void invokeCudaD2DScaleCpyConvert(
    T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0);

inline bool checkIfFileExist(std::string const& file_path)
{
    std::ifstream in(file_path, std::ios::in | std::ios::binary);
    if (in.is_open())
    {
        in.close();
        return true;
    }
    return false;
}

template <typename T>
void saveToBinary(T const* ptr, size_t const size, std::string filename);

template <typename T_IN, typename T_fake_type>
void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream);

size_t cuda_datatype_size(TRTLLMCudaDataType dt);

template <typename T>
bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream);

constexpr size_t DEFAULT_ALIGN_BYTES = 256;

size_t calcAlignedSize(std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);
void calcAlignedPointers(std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes,
    size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES);

struct AlignedPointersUnpacker
{
    template <typename... T>
    void operator()(T*&... outPtrs)
    {
        assert(sizeof...(T) == alignedPointers.size());
        auto it = alignedPointers.begin();
        ((outPtrs = static_cast<T*>(*it++)), ...);
    }

    std::vector<void*> alignedPointers;
};

AlignedPointersUnpacker inline calcAlignedPointers(
    void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES)
{
    AlignedPointersUnpacker unpacker{};
    calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES);
    return unpacker;
}

} // namespace common
} // namespace tensorrt_llm
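Note: to make the flat_index helpers above concrete, a small sketch of row-major indexing into a [batch, rows, cols] layout (the shapes and caller are hypothetical):

// Hypothetical usage sketch: flat_index3(b, r, c, rows, cols) == (b * rows + r) * cols + c.
#include "tensorrt_llm/common/memoryUtils.h" // the header shown above (removed by this commit)
#include <cassert>

int offsetInTensor(int b, int r, int c, int rows, int cols)
{
    int const flat = tensorrt_llm::common::flat_index3(b, r, c, rows, cols);
    assert(flat == (b * rows + r) * cols + c);
    return flat;
}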
sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp (deleted, 100644 → 0)
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <numeric>
#include <unordered_set>
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <csignal>
#include <cstdlib>
#include <mutex>
#include <thread>
#include <type_traits>
#ifndef _WIN32
#include <unistd.h>
#endif
// We rely on SizeType32 being int32_t in some places with weak type checking,
// i.e. we're passing void ptr to some function. To prevent mysterious errors
// in the future, we trigger a compilation error here if SizeType32 isn't int32_t.
static_assert(std::is_same<tensorrt_llm::runtime::SizeType32, std::int32_t>::value);

namespace tensorrt_llm::mpi
{

MPI_Datatype getMpiDtype(MpiType dtype)
{
#if ENABLE_MULTI_DEVICE
    static std::unordered_map<MpiType, MPI_Datatype> const dtype_map{
        {MpiType::kBYTE, MPI_BYTE},
        {MpiType::kHALF, MPI_UINT16_T},
        {MpiType::kFLOAT, MPI_FLOAT},
        {MpiType::kDOUBLE, MPI_DOUBLE},
        {MpiType::kBOOL, MPI_C_BOOL},
        {MpiType::kINT8, MPI_INT8_T},
        {MpiType::kUINT8, MPI_UINT8_T},
        {MpiType::kINT32, MPI_INT32_T},
        {MpiType::kUINT32, MPI_UINT32_T},
        {MpiType::kINT64, MPI_INT64_T},
        {MpiType::kUINT64, MPI_UINT64_T},
        {MpiType::kFP8, MPI_UINT8_T},
        {MpiType::kBF16, MPI_UINT16_T},
        {MpiType::kCHAR, MPI_CHAR},
    };
    return dtype_map.at(dtype);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif
}

MPI_Op getMpiOp(MpiOp op)
{
#if ENABLE_MULTI_DEVICE
    static std::unordered_map<MpiOp, MPI_Op> const op_map{
        {MpiOp::NULLOP, MPI_OP_NULL},
        {MpiOp::MAX, MPI_MAX},
        {MpiOp::MIN, MPI_MIN},
        {MpiOp::SUM, MPI_SUM},
        {MpiOp::PROD, MPI_PROD},
        {MpiOp::LAND, MPI_LAND},
        {MpiOp::BAND, MPI_BAND},
        {MpiOp::LOR, MPI_LOR},
        {MpiOp::BOR, MPI_BOR},
        {MpiOp::LXOR, MPI_LXOR},
        {MpiOp::BXOR, MPI_BXOR},
        {MpiOp::MINLOC, MPI_MINLOC},
        {MpiOp::MAXLOC, MPI_MAXLOC},
        {MpiOp::REPLACE, MPI_REPLACE},
    };
    return op_map.at(op);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

namespace
{

bool mpiInitialized = false;
std::recursive_mutex mpiMutex;

MpiComm initLocalSession()
{
#if ENABLE_MULTI_DEVICE
    MPI_Comm localComm = nullptr;
    MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm);
    MpiComm localSession{localComm, false};
#else
    MpiComm localSession{COMM_SESSION, false};
#endif // ENABLE_MULTI_DEVICE
    return localSession;
}

} // namespace

std::vector<int> getWorldRanks(MpiComm const& comm)
{
#if ENABLE_MULTI_DEVICE
    MPI_Group group = nullptr;
    MPI_Group worldGroup = nullptr;

    MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
    MPICHECK(MPI_Comm_group(comm, &group));

    int groupSize = 0;
    MPICHECK(MPI_Group_size(group, &groupSize));
    std::vector<int> ranks(groupSize);
    std::vector<int> worldRanks(groupSize);
    std::iota(ranks.begin(), ranks.end(), 0);

    MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data()));
    MPICHECK(MPI_Group_free(&group));
    MPICHECK(MPI_Group_free(&worldGroup));
#else
    std::vector<int> worldRanks{0};
#endif
    return worldRanks;
}

void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
{
    // double-checked locking
    if (mpiInitialized)
    {
        return;
    }
    std::lock_guard<std::recursive_mutex> lk(mpiMutex);
    if (mpiInitialized)
    {
        return;
    }
#if ENABLE_MULTI_DEVICE
    int initialized = 0;
    TLLM_MPI_CHECK(MPI_Initialized(&initialized));
    if (!initialized)
    {
        TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode);
        int providedMode = 0;
        auto requiredMode = static_cast<int>(threadMode);
        MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode));
        TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed");
        std::atexit([]() { MPI_Finalize(); });

        /*
         * We only catch SIGABRT and SIGSEGV because most, if not all, errors in the worker will cause one of these 2
         * signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers
         * correctly.
         */
        for (int sig : {SIGABRT, SIGSEGV})
        {
            __sighandler_t previousHandler = nullptr;
            if (forwardAbortToParent)
            {
                previousHandler = std::signal(sig,
                    [](int signal)
                    {
#ifndef _WIN32
                        pid_t parentProcessId = getppid();
                        kill(parentProcessId, SIGKILL);
#endif
                        MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE);
                    });
            }
            else
            {
                previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); });
            }
            TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed");
        }

        // ensure local MPI communicator is initialized
        MpiComm::localSession();
        TLLM_LOG_INFO("Initialized MPI");
    }
#endif // ENABLE_MULTI_DEVICE
    mpiInitialized = true;
}

void MpiComm::barrier() const
{
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Barrier(mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

#if ENABLE_MULTI_DEVICE
template <typename TMpiFunc, typename TBase, typename... TArgs,
    typename = std::enable_if_t<std::is_same_v<void, std::remove_const_t<TBase>>>>
size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args)
{
    constexpr auto maxP1 = static_cast<size_t>(std::numeric_limits<int>::max()) + 1;
    if (TLLM_LIKELY(size < maxP1))
    {
        MPICHECK(func(buffer, size, dtype, args...));
        return 1;
    }

    constexpr size_t alignment = 256;
    int elementSize = 1;
    MPICHECK(MPI_Type_size(dtype, &elementSize));
    elementSize = std::min<int>(elementSize, alignment);

    // We cap at max alignment-bytes chunks that can be sent at once.
    auto const step = maxP1 - (alignment / elementSize);

    using TCast = std::conditional_t<std::is_const_v<TBase>, uint8_t const, uint8_t>;
    size_t count = 0;
    while (size != 0)
    {
        auto currentStep = static_cast<int>(std::min(size, step));
        MPICHECK(func(buffer, currentStep, dtype, args...));
        size -= currentStep;
        size_t diff = static_cast<size_t>(currentStep) * elementSize;
        buffer = static_cast<TCast*>(buffer) + diff;
        ++count;
    }
    return count;
}
#endif // ENABLE_MULTI_DEVICE

std::shared_ptr<MpiRequest> MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const
{
    std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    return r;
}

std::shared_ptr<MpiRequest> MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const
{
    TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU);
    return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}

void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
{
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

void MpiComm::bcast(runtime::IBuffer& buf, int root) const
{
    bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root);
}

std::shared_ptr<MpiRequest> MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
    TLLM_LOG_DEBUG("start MPI_Isend with size %d", size);
    std::shared_ptr<MpiRequest> r = std::make_shared<MpiRequest>();
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif
    TLLM_LOG_DEBUG("end MPI_Isend with size %d", size);
    return r;
}

std::shared_ptr<MpiRequest> MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const
{
    return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}

void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
    TLLM_LOG_DEBUG("start MPI_Send with size %d", size);
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    TLLM_LOG_DEBUG("end MPI_Send with size %d", size);
}

void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const
{
    send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag);
}

MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const
{
    TLLM_LOG_DEBUG("start MPI_Recv with size %d", size);
    MPI_Status status{};
#if ENABLE_MULTI_DEVICE
    invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status);
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    TLLM_LOG_DEBUG("end MPI_Recv with size %d", size);
    return status;
}

MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const
{
    return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag);
}

MpiComm MpiComm::split(int color, int key) const
{
    MPI_Comm splitComm = nullptr;
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
    return MpiComm{splitComm, true};
}

void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
{
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
{
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
    std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
{
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
        getMpiDtype(recvtype), mComm));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status));
#else
    TLLM_THROW("Multi device support is disabled.");
#endif // ENABLE_MULTI_DEVICE
}

bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
    int flag{0};
    MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status));
    return flag != 0;
#else
    TLLM_THROW("Multi device support is disabled.");
    return false;
#endif
}

bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
    int flag{0};
    MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status));
    return flag != 0;
#else
    TLLM_THROW("Multi device support is disabled.");
    return false;
#endif
}

void MpiComm::recvPoll(int source, int tag, int periodMs) const
{
    MPI_Status status;
    while (!iprobe(source, tag, &status))
    {
        std::this_thread::sleep_for(std::chrono::milliseconds(periodMs));
    }
}

int MpiComm::getRank() const
{
    int rank = 0;
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Comm_rank(mComm, &rank));
#endif
    return rank;
}

int MpiComm::getSize() const
{
    int world_size = 1;
#if ENABLE_MULTI_DEVICE
    MPICHECK(MPI_Comm_size(mComm, &world_size));
#endif
    return world_size;
}

MpiComm const& MpiComm::world()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    static MpiComm commWorld{MPI_COMM_WORLD, false};
    initialize();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return commWorld;
}

MpiComm& MpiComm::mutableSession()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    static MpiComm commSession{MPI_COMM_WORLD, false};
    initialize();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return commSession;
}

MpiComm& MpiComm::mutableLocalSession()
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    static MpiComm localSession = initLocalSession();
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return localSession;
}

void MpiComm::refreshLocalSession()
{
#if ENABLE_MULTI_DEVICE
    static std::mutex mutex;
    std::unique_lock lock(mutex);
    auto initSessionRanks = getWorldRanks(MpiComm::session());
    auto localSessionRanks = getWorldRanks(MpiComm::localSession());

    // Add to intersectionRanks in order of initSessionRanks
    std::vector<int> intersectionRanks;
    std::unordered_set<int> localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end());
    for (auto rank : initSessionRanks)
    {
        if (localSessionRanksSet.find(rank) != localSessionRanksSet.end())
        {
            intersectionRanks.push_back(rank);
        }
    }

    MPI_Group worldGroup = nullptr;
    MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
    MPI_Group localGroup = nullptr;
    MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup));
    MPI_Comm localComm = nullptr;
    MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm));
    MpiComm::mutableLocalSession().mFreeComm = true;
    MpiComm::mutableLocalSession() = MpiComm{localComm, false};
    TLLM_LOG_INFO("Refreshed the MPI local session");
#endif // ENABLE_MULTI_DEVICE
}

MpiComm::MpiComm(MPI_Comm g, bool freeComm)
    : mComm{g}
    , mFreeComm{freeComm}
{
    TLLM_CHECK(mComm != MPI_COMM_NULL);
}

MpiComm::~MpiComm() noexcept
{
#if ENABLE_MULTI_DEVICE
    if (mFreeComm && mComm)
    {
        if (MPI_Comm_free(&mComm) != MPI_SUCCESS)
        {
            TLLM_LOG_ERROR("MPI_Comm_free failed");
        }
    }
#endif // ENABLE_MULTI_DEVICE
}

MpiComm::MpiComm(MpiComm&& comm) noexcept
    : mComm{comm.mComm}
    , mFreeComm{comm.mFreeComm}
{
    comm.mFreeComm = false;
}

MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept
{
    this->~MpiComm();
    mComm = comm.mComm;
    mFreeComm = comm.mFreeComm;
    comm.mFreeComm = false;
    return *this;
}

MpiWaitThread::MpiWaitThread(std::string name, std::function<void()> funcWait, std::function<void()> funcSetup)
    : mName{name.c_str()}
    , mFuncWait{funcWait}
    , mFuncSetup{funcSetup}
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    mThread = std::make_unique<std::thread>(&MpiWaitThread::sideThread, this);
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

MpiWaitThread::~MpiWaitThread()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    waitStop();
    mShouldExit.store(true);
    notifyStart();
    mThread->join();
    mThread.reset(nullptr);
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

void MpiWaitThread::sideThread()
{
    if (mFuncSetup)
    {
        mFuncSetup();
    }
    while (!mShouldExit.load())
    {
        notifyStop();
        waitStart();
        mFuncWait();
    }
}

void MpiWaitThread::waitStart()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::unique_lock<std::mutex> lock(mMutex);
    mCondVar.wait(lock, [this] { return mRunning; });
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

void MpiWaitThread::waitStop()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::unique_lock<std::mutex> lock(mMutex);
    mCondVar.wait(lock, [this] { return !mRunning; });
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

void MpiWaitThread::notifyStart()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::lock_guard<std::mutex> lock(mMutex);
    mRunning = true;
    mCondVar.notify_one();
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

void MpiWaitThread::notifyStop()
{
    TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__);
    std::lock_guard<std::mutex> lock(mMutex);
    mRunning = false;
    mCondVar.notify_one();
    TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__);
}

} // namespace tensorrt_llm::mpi
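For orientation, a minimal sketch of how the MpiComm wrapper above is typically driven, assuming a build with ENABLE_MULTI_DEVICE and the MpiComm/MpiType API shown in this file; the payload and function name are hypothetical:

#include <cstdint>
#include <vector>

void broadcastConfig()
{
    using namespace tensorrt_llm::mpi;
    auto const& comm = MpiComm::world();
    std::vector<int32_t> config(4, 0);
    if (comm.getRank() == 0)
    {
        config = {1, 2, 3, 4}; // only the root fills the payload
    }
    // size is the element count for the given MpiType; oversized transfers are chunked internally.
    comm.bcast(config.data(), config.size(), MpiType::kINT32, /*root=*/0);
}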
sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h
deleted
100644 → 0
View file @
e81d7f11
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nvtx3/nvtx3.hpp>
#include <array>
namespace tensorrt_llm::common::nvtx
{
inline nvtx3::color nextColor()
{
#ifndef NVTX_DISABLE
    constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00},
        nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}};
    constexpr auto numColors = kColors.size();

    static thread_local std::size_t colorId = 0;
    auto const color = kColors[colorId];
    colorId = colorId + 1 >= numColors ? 0 : colorId + 1;
    return color;
#else
    return nvtx3::color{0};
#endif
}

} // namespace tensorrt_llm::common::nvtx

#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \
    ::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name)
#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range)
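A short usage sketch of the scoped-range macro defined above (the function name is hypothetical); the range opens where the statement executes and closes when the enclosing scope exits:

void decodeStep()
{
    NVTX3_SCOPED_RANGE(decode_step); // appears as "decode_step" in the profiler timeline
    // ... enqueue the kernels for one decoding step ...
}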
sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp
deleted
100644 → 0
View file @
e81d7f11
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "cuda.h"
#include <cstdint>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <functional>
#include <mutex>
#include <thread>
#ifdef _MSC_VER
#define FN_NAME __FUNCTION__
#else
#define FN_NAME __func__
#endif
#if ENABLE_MULTI_DEVICE
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap()
{
    static std::unordered_map<nvinfer1::DataType, ncclDataType_t> dtypeMap
        = {{nvinfer1::DataType::kFLOAT, ncclFloat32}, {nvinfer1::DataType::kHALF, ncclFloat16},
            {nvinfer1::DataType::kBF16, ncclBfloat16}};
    return &dtypeMap;
}

namespace
{

// Get NCCL unique ID for a group of ranks.
ncclUniqueId getUniqueId(std::set<int> const& group) noexcept
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
    ncclUniqueId id;
    if (rank == *group.begin())
    {
        NCCLCHECK(ncclGetUniqueId(&id));
        for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
        {
            COMM_SESSION.sendValue(id, *it, 0);
        }
    }
    else
    {
        COMM_SESSION.recvValue(id, *group.begin(), 0);
    }
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return id;
}
} // namespace

std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group)
{
    auto const rank = COMM_SESSION.getRank();
    TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank);
    static std::map<std::set<int>, std::shared_ptr<ncclComm_t>> commMap;
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    std::ostringstream oss;
    int index = 0;
    for (auto const& rank : group)
    {
        if (index != 0)
        {
            oss << ",";
        }
        oss << rank;
        index++;
    }
    auto groupStr = oss.str();
    auto it = commMap.find(group);
    if (it != commMap.end())
    {
        auto ncclComm = it->second;
        TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank);
        return ncclComm;
    }

    TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank);
    ncclUniqueId id = getUniqueId(group);
    int groupRank = 0;
    for (auto const& currentRank : group)
    {
        if (rank == currentRank)
            break;
        ++groupRank;
    }
    TLLM_CHECK(groupRank < group.size());
    std::shared_ptr<ncclComm_t> ncclComm(new ncclComm_t,
        [](ncclComm_t* comm)
        {
            ncclCommDestroy(*comm);
            delete comm;
        });
    NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank));
    commMap[group] = ncclComm;
    TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank);
    return ncclComm;
}
#endif // ENABLE_MULTI_DEVICE

void const* tensorrt_llm::common::getCommSessionHandle()
{
#if ENABLE_MULTI_DEVICE
    return &COMM_SESSION;
#else
    return nullptr;
#endif // ENABLE_MULTI_DEVICE
}

namespace
{

// Get current cuda context, a default context will be created if there is no context.
inline CUcontext getCurrentCudaCtx()
{
    CUcontext ctx{};
    CUresult err = cuCtxGetCurrent(&ctx);
    if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr)
    {
        TLLM_CUDA_CHECK(cudaFree(nullptr));
        err = cuCtxGetCurrent(&ctx);
    }
    TLLM_CHECK(err == CUDA_SUCCESS);
    return ctx;
}

// Helper to create per-cuda-context singleton managed by std::shared_ptr.
// Unlike conventional singletons, singleton created with this will be released
// when not needed, instead of on process exit.
// Objects of this class shall always be declared static / global, and shall never own CUDA
// resources.
template <typename T>
class PerCudaCtxSingletonCreator
{
public:
    using CreatorFunc = std::function<std::unique_ptr<T>()>;
    using DeleterFunc = std::function<void(T*)>;

    // creator returning std::unique_ptr is by design.
    // It forces separation of memory for T and memory for control blocks.
    // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released.
    // creator itself must not own CUDA resources. Only the object it creates can.
    PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
        : mCreator{std::move(creator)}
        , mDeleter{std::move(deleter)}
    {
    }

    std::shared_ptr<T> operator()()
    {
        std::lock_guard<std::mutex> lk{mMutex};
        CUcontext ctx{getCurrentCudaCtx()};
        std::shared_ptr<T> result = mObservers[ctx].lock();
        if (result == nullptr)
        {
            // Create the resource and register with an observer.
            result = std::shared_ptr<T>{mCreator().release(),
                [this, ctx](T* obj)
                {
                    if (obj == nullptr)
                    {
                        return;
                    }
                    mDeleter(obj);

                    // Clears the observer to avoid growth of mObservers, in case users create/destroy cuda contexts
                    // frequently.
                    std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
                    std::lock_guard<std::mutex> lk{mMutex};
                    // Must check the observer again because another thread may have created a new instance for this
                    // ctx just before we lock mMutex. We can't infer that the observer is stale from the fact that
                    // obj is destroyed, because shared_ptr ref-count checking and observer removing are not in one
                    // atomic operation, and the observer may be changed to observe another instance.
                    observedObjHolder = mObservers.at(ctx).lock();
                    if (observedObjHolder == nullptr)
                    {
                        mObservers.erase(ctx);
                    }
                }};
            mObservers.at(ctx) = result;
        }
        return result;
    }

private:
    CreatorFunc mCreator;
    DeleterFunc mDeleter;
    mutable std::mutex mMutex;
    // CUDA resources are per-context.
    std::unordered_map<CUcontext, std::weak_ptr<T>> mObservers;
};

template <typename T>
class PerThreadSingletonCreator
{
public:
    using CreatorFunc = std::function<std::unique_ptr<T>()>;
    using DeleterFunc = std::function<void(T*)>;

    // creator returning std::unique_ptr is by design.
    // It forces separation of memory for T and memory for control blocks.
    // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released.
    // creator itself must not own CUDA resources. Only the object it creates can.
    PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter)
        : mCreator{std::move(creator)}
        , mDeleter{std::move(deleter)}
    {
    }

    std::shared_ptr<T> operator()()
    {
        std::lock_guard<std::mutex> lk{mMutex};
        std::thread::id thread = std::this_thread::get_id();
        std::shared_ptr<T> result = mObservers[thread].lock();
        if (result == nullptr)
        {
            // Create the resource and register with an observer.
            result = std::shared_ptr<T>{mCreator().release(),
                [this, thread](T* obj)
                {
                    if (obj == nullptr)
                    {
                        return;
                    }
                    mDeleter(obj);

                    // Clears the observer to avoid growth of mObservers, in case users create/destroy threads
                    // frequently.
                    std::shared_ptr<T> observedObjHolder; // Delay destroy to avoid dead lock.
                    std::lock_guard<std::mutex> lk{mMutex};
                    // Must check the observer again because another thread may have created a new instance for this
                    // thread id just before we lock mMutex. We can't infer that the observer is stale from the fact
                    // that obj is destroyed, because shared_ptr ref-count checking and observer removing are not in
                    // one atomic operation, and the observer may be changed to observe another instance.
                    observedObjHolder = mObservers.at(thread).lock();
                    if (observedObjHolder == nullptr)
                    {
                        mObservers.erase(thread);
                    }
                }};
            mObservers.at(thread) = result;
        }
        return result;
    }

private:
    CreatorFunc mCreator;
    DeleterFunc mDeleter;
    mutable std::mutex mMutex;
    // CUDA resources are per-thread.
    std::unordered_map<std::thread::id, std::weak_ptr<T>> mObservers;
};

} // namespace

std::shared_ptr<cublasHandle_t> getCublasHandle()
{
    static PerThreadSingletonCreator<cublasHandle_t> creator(
        []() -> auto
        {
            auto handle = std::unique_ptr<cublasHandle_t>(new cublasHandle_t);
            TLLM_CUDA_CHECK(cublasCreate(handle.get()));
            return handle;
        },
        [](cublasHandle_t* handle)
        {
            TLLM_CUDA_CHECK(cublasDestroy(*handle));
            delete handle;
        });
    return creator();
}

std::shared_ptr<cublasLtHandle_t> getCublasLtHandle()
{
    static PerThreadSingletonCreator<cublasLtHandle_t> creator(
        []() -> auto
        {
            auto handle = std::unique_ptr<cublasLtHandle_t>(new cublasLtHandle_t);
            TLLM_CUDA_CHECK(cublasLtCreate(handle.get()));
            return handle;
        },
        [](cublasLtHandle_t* handle)
        {
            TLLM_CUDA_CHECK(cublasLtDestroy(*handle));
            delete handle;
        });
    return creator();
}

std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
    std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
{
    static PerThreadSingletonCreator<tensorrt_llm::common::CublasMMWrapper> creator(
        [cublasHandle, cublasltHandle, stream, workspace]() -> auto
        {
            auto wrapper = std::unique_ptr<tensorrt_llm::common::CublasMMWrapper>(
                new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace));
            return wrapper;
        },
        [](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; });
    return creator();
}
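For context, a minimal sketch of how the per-thread cuBLAS handle cache above is typically consumed, assuming the declarations from opUtils.h; the function name is hypothetical:

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <memory>

void enqueueGemm(cudaStream_t stream)
{
    std::shared_ptr<cublasHandle_t> handle = getCublasHandle(); // created lazily, cached per thread
    cublasSetStream(*handle, stream);
    // ... issue cuBLAS calls on *handle; the handle is destroyed only after the last
    // shared_ptr owner on this thread releases it.
}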
sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h
deleted
100644 → 0
View file @
e81d7f11
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/workspace.h"
#include <NvInferRuntime.h>
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#if ENABLE_MULTI_DEVICE
#include <nccl.h>
#endif // ENABLE_MULTI_DEVICE
#include <cstring>
#include <map>
#include <memory>
#include <nvml.h>
#include <optional>
#include <set>
#include <string>
#include <unordered_map>
namespace tensorrt_llm::common
{

// Write values into buffer
template <typename T>
void write(char*& buffer, T const& val)
{
    std::memcpy(buffer, &val, sizeof(T));
    buffer += sizeof(T);
}

// Read values from buffer
template <typename T>
void read(char const*& buffer, T& val)
{
    std::memcpy(&val, buffer, sizeof(T));
    buffer += sizeof(T);
}

// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members.
// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and
// your clone() implementation is responsible for initializing such data members.
// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr.
template <typename T, typename Del = std::default_delete<T>>
class UniqPtrWNullCopy : public std::unique_ptr<T, Del>
{
public:
    using std::unique_ptr<T, Del>::unique_ptr;

    // for compatibility with std::make_unique
    explicit UniqPtrWNullCopy(std::unique_ptr<T, Del>&& src)
        : std::unique_ptr<T, Del>::unique_ptr{std::move(src)}
    {
    }

    // copy constructor produces nullptr
    UniqPtrWNullCopy(UniqPtrWNullCopy const&)
        : std::unique_ptr<T, Del>::unique_ptr{}
    {
    }
};

// for testing only
void const* getCommSessionHandle();

} // namespace tensorrt_llm::common

inline bool isBuilding()
{
    auto constexpr key = "IS_BUILDING";
    auto const val = getenv(key);
    return val != nullptr && std::string(val) == "1";
}
#if ENABLE_MULTI_DEVICE
#define NCCLCHECK(cmd) \
do \
{ \
ncclResult_t r = cmd; \
if (r != ncclSuccess) \
{ \
printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
std::unordered_map<nvinfer1::DataType, ncclDataType_t>* getDtypeMap();

std::shared_ptr<ncclComm_t> getComm(std::set<int> const& group);
#endif // ENABLE_MULTI_DEVICE
//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally.
//! Get cublas and cublasLt handle for current cuda context
std::shared_ptr<cublasHandle_t> getCublasHandle();
std::shared_ptr<cublasLtHandle_t> getCublasLtHandle();
std::shared_ptr<tensorrt_llm::common::CublasMMWrapper> getCublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
    std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace);
#ifndef DEBUG
#define PLUGIN_CHECK(status) \
do \
{ \
if (status != 0) \
abort(); \
} while (0)
#define ASSERT_PARAM(exp) \
do \
{ \
if (!(exp)) \
return STATUS_BAD_PARAM; \
} while (0)
#define ASSERT_FAILURE(exp) \
do \
{ \
if (!(exp)) \
return STATUS_FAILURE; \
} while (0)
#define CSC(call, err) \
do \
{ \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) \
{ \
return err; \
} \
} while (0)
#define DEBUG_PRINTF(...) \
do \
{ \
} while (0)
#else
#define ASSERT_PARAM(exp) \
do \
{ \
if (!(exp)) \
{ \
fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_BAD_PARAM; \
} \
} while (0)
#define ASSERT_FAILURE(exp) \
do \
{ \
if (!(exp)) \
{ \
fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \
return STATUS_FAILURE; \
} \
} while (0)
#define CSC(call, err) \
do \
{ \
cudaError_t cudaStatus = call; \
if (cudaStatus != cudaSuccess) \
{ \
printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \
return err; \
} \
} while (0)
#define PLUGIN_CHECK(status) \
{ \
if (status != 0) \
{ \
DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \
abort(); \
} \
}
#define DEBUG_PRINTF(...) \
do \
{ \
printf(__VA_ARGS__); \
} while (0)
#endif // DEBUG
#define NVML_CHECK(cmd) \
do \
{ \
nvmlReturn_t r = cmd; \
if (r != NVML_SUCCESS) \
{ \
printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while (0)
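To illustrate the write/read helpers declared near the top of this header in their usual serialize/deserialize pairing (the Fields struct and function are hypothetical):

#include <vector>

struct Fields
{
    int hiddenSize;
    float eps;
};

void roundTrip(Fields const& in, Fields& out)
{
    std::vector<char> buf(sizeof(in.hiddenSize) + sizeof(in.eps));
    char* w = buf.data();
    tensorrt_llm::common::write(w, in.hiddenSize); // advances w by sizeof(int)
    tensorrt_llm::common::write(w, in.eps);

    char const* r = buf.data();
    tensorrt_llm::common::read(r, out.hiddenSize); // advances r by sizeof(int)
    tensorrt_llm::common::read(r, out.eps);
}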
sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h
0 → 100644
View file @
9602c2aa
/*
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstdint>
#include <optional>
#include <string>
namespace tensorrt_llm
{
namespace common
{

class QuantMode
{
    // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py
public:
    using BaseType = std::uint32_t;

    explicit constexpr QuantMode(BaseType value) noexcept
        : mValue{value}
    {
    }

    QuantMode() noexcept = default;

    constexpr QuantMode(QuantMode const&) noexcept = default;

    constexpr QuantMode& operator=(QuantMode const& other) noexcept = default;

    static constexpr QuantMode none() noexcept
    {
        return QuantMode(BaseType(0));
    }

    static constexpr QuantMode int4Weights() noexcept
    {
        return QuantMode(BaseType(1u) << 0);
    }

    static constexpr QuantMode int8Weights() noexcept
    {
        return QuantMode(BaseType(1u) << 1);
    }

    static constexpr QuantMode activations() noexcept
    {
        return QuantMode(BaseType(1u) << 2);
    }

    static constexpr QuantMode perChannelScaling() noexcept
    {
        return QuantMode(BaseType(1u) << 3);
    }

    static constexpr QuantMode perTokenScaling() noexcept
    {
        return QuantMode(BaseType(1u) << 4);
    }

    static constexpr QuantMode perGroupScaling() noexcept
    {
        return QuantMode(BaseType(1u) << 5);
    }

    static constexpr QuantMode int8KvCache() noexcept
    {
        return QuantMode(BaseType(1u) << 6);
    }

    static constexpr QuantMode fp8KvCache() noexcept
    {
        return QuantMode(BaseType(1u) << 7);
    }

    static constexpr QuantMode fp8Qdq() noexcept
    {
        return QuantMode(BaseType(1u) << 8);
    }

    static constexpr QuantMode fp8RowWise() noexcept
    {
        return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9);
    }

    static constexpr QuantMode w4a8QServe() noexcept
    {
        return QuantMode(BaseType(1u) << 10);
    }

    constexpr BaseType value() const noexcept
    {
        return mValue;
    }

    constexpr bool isSet(QuantMode const& mode) const noexcept
    {
        return (mValue & mode.value()) == mode.value();
    }

    constexpr bool hasInt4Weights() const noexcept
    {
        return isSet(int4Weights());
    }

    constexpr bool hasInt8Weights() const noexcept
    {
        return isSet(int8Weights());
    }

    constexpr bool hasActivations() const noexcept
    {
        return isSet(activations());
    }

    constexpr bool hasPerChannelScaling() const noexcept
    {
        return isSet(perChannelScaling());
    }

    constexpr bool hasPerTokenScaling() const noexcept
    {
        return isSet(perTokenScaling());
    }

    constexpr bool hasPerGroupScaling() const noexcept
    {
        return isSet(perGroupScaling());
    }

    constexpr bool hasStaticActivationScaling() const noexcept
    {
        return !hasPerTokenScaling();
    }

    constexpr bool hasInt8KvCache() const noexcept
    {
        return isSet(int8KvCache());
    }

    constexpr bool hasFp8KvCache() const noexcept
    {
        return isSet(fp8KvCache());
    }

    constexpr bool hasFp8Qdq() const noexcept
    {
        return isSet(fp8Qdq());
    }

    constexpr bool hasFp8RowWise() const noexcept
    {
        return isSet(fp8RowWise());
    }

    constexpr bool hasKvCacheQuant() const noexcept
    {
        return hasInt8KvCache() || hasFp8KvCache();
    }

    static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false,
        bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false,
        bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false,
        bool useW4a8QServe = false)
    {
        QuantMode quantMode{};
        if (quantizeWeights)
        {
            if (useInt4Weights)
                quantMode += int4Weights();
            else
                quantMode += int8Weights();
        }

        if (quantizeActivations)
        {
            quantMode += activations();
        }

        if (perChannel)
        {
            quantMode += QuantMode::perChannelScaling();
        }
        if (perToken)
        {
            quantMode += QuantMode::perTokenScaling();
        }
        if (perGroup)
        {
            quantMode += QuantMode::perGroupScaling();
        }

        if (useInt8KvCache)
        {
            quantMode += int8KvCache();
        }

        if (useFp8KvCache)
        {
            quantMode += fp8KvCache();
        }

        if (useFp8Qdq)
        {
            quantMode += fp8Qdq();
        }

        if (useFp8RowWise)
        {
            quantMode += fp8RowWise();
        }

        if (useW4a8QServe)
        {
            quantMode += w4a8QServe();
        }

        return quantMode;
    }

    static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
    {
        return fromDescription(true, true, perToken, perChannel);
    }

    static constexpr QuantMode useQServe(bool perGroup)
    {
        return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true);
    }

    static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
    {
        return fromDescription(true, false, false, false, perGroup, useInt4Weights);
    }

    static QuantMode const fromQuantAlgo(std::optional<std::string> quantAlgo = std::nullopt,
        std::optional<std::string> kvCacheQuantAlgo = std::nullopt)
    {
        QuantMode quantMode{};
        if (quantAlgo == "W8A16")
        {
            quantMode = useWeightOnly(false, false);
        }
        else if (quantAlgo == "W4A16")
        {
            quantMode = useWeightOnly(true, false);
        }
        else if (quantAlgo == "W4A16_AWQ")
        {
            quantMode = useWeightOnly(true, true);
        }
        else if (quantAlgo == "W4A8_AWQ")
        {
            quantMode = useWeightOnly(true, true);
        }
        else if (quantAlgo == "W4A8_QSERVE_PER_GROUP")
        {
            quantMode = useQServe(false);
        }
        else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL")
        {
            quantMode = useQServe(true);
        }
        else if (quantAlgo == "W4A16_GPTQ")
        {
            quantMode = useWeightOnly(true, true);
        }
        else if (quantAlgo == "W8A8_SQ_PER_CHANNEL")
        {
            quantMode = useSmoothQuant(false, true);
        }
        else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN")
        {
            quantMode = useSmoothQuant(false, false);
        }
        else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN")
        {
            quantMode = useSmoothQuant(true, true);
        }
        else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN")
        {
            quantMode = useSmoothQuant(false, true);
        }
        else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN")
        {
            quantMode = useSmoothQuant(true, false);
        }
        else if (quantAlgo == "FP8")
        {
            quantMode = fromDescription(false, false, false, false, false, false, false, false, true);
        }
        else if (quantAlgo == "FP8_ROWWISE")
        {
            quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true);
        }

        if (kvCacheQuantAlgo == "INT8")
        {
            quantMode += int8KvCache();
        }
        else if (kvCacheQuantAlgo == "FP8")
        {
            quantMode += fp8KvCache();
        }

        return quantMode;
    }

    constexpr QuantMode operator+(QuantMode const& other) const noexcept
    {
        return QuantMode(mValue | other.mValue);
    }

    constexpr QuantMode& operator+=(QuantMode const& other) noexcept
    {
        return *this = *this + other;
    }

    constexpr QuantMode operator-(QuantMode const& other) const noexcept
    {
        return QuantMode(mValue & ~other.mValue);
    }

    constexpr QuantMode& operator-=(QuantMode const& other) noexcept
    {
        return *this = *this - other;
    }

    constexpr bool operator==(QuantMode const& other) const noexcept
    {
        return mValue == other.mValue;
    }

    constexpr bool operator!=(QuantMode const& other) const noexcept
    {
        return !(*this == other);
    }

private:
    BaseType mValue{0};
};

} // namespace common
} // namespace tensorrt_llm
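Finally, a brief sketch of how the QuantMode bit flags above compose and are queried, assuming only this header; the function name is hypothetical:

#include <cassert>

void quantModeExample()
{
    using tensorrt_llm::common::QuantMode;

    // Build a mode from a quantization algorithm name plus a KV-cache algorithm name.
    auto mode = QuantMode::fromQuantAlgo("FP8", "FP8");
    assert(mode.hasFp8Qdq() && mode.hasFp8KvCache());

    // Flags combine and subtract like a bitset via operator+ / operator-.
    mode += QuantMode::int8KvCache();
    assert(mode.hasKvCacheQuant());
    mode -= QuantMode::int8KvCache();
    assert(mode.hasFp8KvCache()); // the FP8 KV-cache flag is untouched
}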