fengzch-das / nunchaku · Commits · 0a7c8614

Commit 0a7c8614, authored Nov 21, 2025 by fengzch-das

    Revert "hipify code"

    This reverts commit 1a8114bf.

parent 1a8114bf
Pipeline #3050 failed with stages in 0 seconds
Changes: 50 · Pipelines: 1
Showing 20 changed files with 248 additions and 254 deletions (+248 −254)
nunchaku/csrc/gemm.h                      +3  −3
nunchaku/csrc/gemm88.h                    +2  −2
nunchaku/csrc/utils.h                     +10 −10
src/FluxModel.cpp                         +60 −60
src/Linear.cpp                            +23 −23
src/Linear.h                              +1  −1
src/Module.h                              +5  −5
src/SanaModel.cpp                         +11 −11
src/Serialization.cpp                     +5  −5
src/Serialization.h                       +5  −5
src/Tensor.h                              +39 −40
src/common.h                              +32 −32
src/interop/torch.cpp                     +4  −4
src/kernels/activation_kernels.cu         +10 −11
src/kernels/activation_kernels_impl.cuh   +0  −1
src/kernels/awq/dequantize.cuh            +7  −7
src/kernels/awq/gemm_awq.cu               +16 −17
src/kernels/awq/gemv_awq.cu               +9  −10
src/kernels/awq/semaphore.h               +2  −3
src/kernels/dispatch_utils.h              +4  −4
nunchaku/csrc/gemm.h

@@ -12,8 +12,8 @@ public:
         spdlog::info("Initializing QuantizedGEMM");

         size_t val = 0;
-        checkCUDA(hipDeviceSetLimit(hipLimitStackSize, 8192));
-        checkCUDA(hipDeviceGetLimit(&val, hipLimitStackSize));
+        checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
+        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
         spdlog::debug("Stack={}", val);

         net = std::make_unique<GEMM_W4A4>((int)in_features,

@@ -42,7 +42,7 @@ public:
     std::string dumpTensorBF16(Tensor x) {
         std::stringstream ss;
         for (int i = 0; i < 256; i++) {
-            ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__hip_bfloat16>()[i]));
+            ss << spdlog::fmt_lib::format("{:.3f} ", (float)(x.data_ptr<__nv_bfloat16>()[i]));
         }
         ss << std::endl;
         return ss.str();
nunchaku/csrc/gemm88.h

@@ -12,8 +12,8 @@ public:
         spdlog::info("Initializing QuantizedGEMM88");

         size_t val = 0;
-        checkCUDA(hipDeviceSetLimit(hipLimitStackSize, 8192));
-        checkCUDA(hipDeviceGetLimit(&val, hipLimitStackSize));
+        checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, 8192));
+        checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
         spdlog::debug("Stack={}", val);

         net = std::make_unique<GEMM_W8A8>(
nunchaku/csrc/utils.h

@@ -8,27 +8,27 @@ namespace nunchaku::utils {

 void set_cuda_stack_limit(int64_t newval) {
     size_t val = 0;
-    checkCUDA(hipDeviceSetLimit(hipLimitStackSize, (size_t)newval));
-    checkCUDA(hipDeviceGetLimit(&val, hipLimitStackSize));
+    checkCUDA(cudaDeviceSetLimit(cudaLimitStackSize, (size_t)newval));
+    checkCUDA(cudaDeviceGetLimit(&val, cudaLimitStackSize));
     spdlog::debug("Stack={}", val);
 }

 void disable_memory_auto_release() {
     int device;
-    checkCUDA(hipGetDevice(&device));
-    hipMemPool_t mempool;
-    checkCUDA(hipDeviceGetDefaultMemPool(&mempool, device));
+    checkCUDA(cudaGetDevice(&device));
+    cudaMemPool_t mempool;
+    checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
     uint64_t threshold = UINT64_MAX;
-    checkCUDA(hipMemPoolSetAttribute(mempool, hipMemPoolAttrReleaseThreshold, &threshold));
+    checkCUDA(cudaMemPoolSetAttribute(mempool, cudaMemPoolAttrReleaseThreshold, &threshold));
 }

 void trim_memory() {
     int device;
-    checkCUDA(hipGetDevice(&device));
-    hipMemPool_t mempool;
-    checkCUDA(hipDeviceGetDefaultMemPool(&mempool, device));
+    checkCUDA(cudaGetDevice(&device));
+    cudaMemPool_t mempool;
+    checkCUDA(cudaDeviceGetDefaultMemPool(&mempool, device));
     size_t bytesToKeep = 0;
-    checkCUDA(hipMemPoolTrimTo(mempool, bytesToKeep));
+    checkCUDA(cudaMemPoolTrimTo(mempool, bytesToKeep));
 }

 void set_faster_i2f_mode(std::string mode) {
src/FluxModel.cpp  (+60 −60)

(This diff is collapsed.)
src/Linear.cpp

@@ -5,7 +5,7 @@
 #include "kernels/awq/gemv_awq.h"
 #include "kernels/dwconv.h"
-#include <nvtx3/roctracer/roctx.h>
+#include <nvtx3/nvToolsExt.h>

 using namespace nunchaku;

@@ -117,7 +117,7 @@ GEMM_W4A4::GEMM_W4A4(
         wtscale, "wtscale", ParamFlags::Optional)(wcscales, "wcscales", ParamFlags::Optional);
 #if NO_LORA_FUSION
-    checkCUBLAS(hipblasCreate(&handle));
+    checkCUBLAS(cublasCreate(&handle));
 #endif
 }

@@ -140,7 +140,7 @@ void GEMM_W4A4::loadParam(std::string key, Tensor &dst, Tensor src) {
     } else if (key == "wtscale") {
         assert(src.numel() == 1);
         if (src.dtype() == Tensor::BF16) {
-            *dst.data_ptr<float>() = float(*src.data_ptr<__hip_bfloat16>());
+            *dst.data_ptr<float>() = float(*src.data_ptr<__nv_bfloat16>());
         } else if (src.dtype() == Tensor::FP16) {
             *dst.data_ptr<float>() = float(*src.data_ptr<half>());
         } else if (src.dtype() == Tensor::FP32) {

@@ -242,15 +242,15 @@ void GEMM_W4A4::forward(Tensor x,
         qact.is_unsigned,
         this->lora_scales);

-    roctxRangePushA("LoraUp");
+    nvtxRangePushA("LoraUp");

     static const half one  = 1.0;
     static const half zero = 0.0;

     // lora_up: [M, R] * [OC, R] => [M, OC]
     // cublas view: [OC, R] * [M, R]^T
-    checkCUBLAS(hipblasHgemm(handle,
-                             HIPBLAS_OP_T,
-                             HIPBLAS_OP_N,
+    checkCUBLAS(cublasHgemm(handle,
+                            CUBLAS_OP_T,
+                            CUBLAS_OP_N,
                             this->out_features,
                             M,
                             this->lora_rank,

@@ -263,7 +263,7 @@ void GEMM_W4A4::forward(Tensor x,
                             out.data_ptr<half>(),
                             this->out_features));

-    roctxRangePop();
+    nvtxRangePop();
 #endif
 }

@@ -380,7 +380,7 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
         qact.is_unsigned,
         this->lora_scales);

-    roctxRangePushA("LoraUp");
+    nvtxRangePushA("LoraUp");

     static const half one  = 1.0;
     static const half zero = 0.0;

@@ -388,9 +388,9 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
     // lora_up: [M, R] * [OC, R]^T => [M, OC]
     // cublas view: [R, OC]^T * [R, M] => [OC, M]
     // lora_up layout wrong?
-    checkCUBLAS(hipblasHgemm(handle,
-                             HIPBLAS_OP_T,
-                             HIPBLAS_OP_N,
+    checkCUBLAS(cublasHgemm(handle,
+                            CUBLAS_OP_T,
+                            CUBLAS_OP_N,
                             this->out_features,
                             M,
                             this->lora_rank,

@@ -403,16 +403,16 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
                             out.data_ptr<half>(),
                             this->out_features));

-    roctxRangePop();
+    nvtxRangePop();

     if (fuse == FuseOptions::GELU_QUANT) {
-        roctxRangePushA("LoraDown");
+        nvtxRangePushA("LoraDown");

         // IC is for next lora (OC of this layer)
         // lora_down: [M, IC] * [IC, R] => [M, R]
         // cublas view: [R, IC] * [IC, M] => [R, M]
-        checkCUBLAS(hipblasHgemm(handle,
-                                 HIPBLAS_OP_N,
-                                 HIPBLAS_OP_N,
+        checkCUBLAS(cublasHgemm(handle,
+                                CUBLAS_OP_N,
+                                CUBLAS_OP_N,
                                 this->lora_rank,
                                 M,
                                 this->out_features,

@@ -427,7 +427,7 @@ GEMM_W4A4::forward_quant(QuantizedActivation qact, FuseOptions fuse, GEMM_W4A4 *
         out = {};

-        roctxRangePop();
+        nvtxRangePop();
     }
 #endif

@@ -473,13 +473,13 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
     static const half one  = 1.0;
     static const half zero = 0.0;

-    roctxRangePushA("LoraDown");
+    nvtxRangePushA("LoraDown");

     // lora_down: [M, IC] * [IC, R] => [M, R]
     // cublas view: [R, IC] * [IC, M]
-    checkCUBLAS(hipblasHgemm(handle,
-                             HIPBLAS_OP_N,
-                             HIPBLAS_OP_N,
+    checkCUBLAS(cublasHgemm(handle,
+                            CUBLAS_OP_N,
+                            CUBLAS_OP_N,
                             this->lora_rank,
                             M,
                             this->in_features,

@@ -492,7 +492,7 @@ GEMM_W4A4::QuantizedActivation GEMM_W4A4::quantize(Tensor x, bool fuse_glu) {
                             qact.lora_act.data_ptr<half>(),
                             this->lora_rank));

-    roctxRangePop();
+    nvtxRangePop();

     kernels::quantize_w4a4_act(x, qact.act, qact.ascales);
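Note on the cuBLAS calls restored above: the "cublas view" comments rely on the usual trick that a row-major matrix is a column-major matrix of the transposed shape. Below is a minimal, self-contained sketch of the lora_up shape, [M, R] x [OC, R]^T => [M, OC]; the function and variable names are illustrative assumptions, not code from this repository.

// Hypothetical sketch, not repository code: lora_up via cublasHgemm.
// cuBLAS is column-major, so the row-major [M, OC] output is produced as a
// column-major [OC, M] matrix: C(OC x M) = W^T(OC x R) * A(R x M).
#include <cublas_v2.h>
#include <cuda_fp16.h>

void lora_up_sketch(cublasHandle_t handle,
                    const half *lora_weight, // [OC, R], row-major
                    const half *lora_act,    // [M, R],  row-major
                    half *out,               // [M, OC], row-major
                    int M, int OC, int R) {
    static const half one  = 1.0;
    static const half zero = 0.0;
    cublasHgemm(handle,
                CUBLAS_OP_T, CUBLAS_OP_N,
                OC, M, R,
                &one,
                lora_weight, R, // [OC, R] row-major reads as [R, OC] col-major; lda = R
                lora_act, R,    // [M, R]  row-major reads as [R, M]  col-major; ldb = R
                &zero,
                out, OC);       // column-major [OC, M] == row-major [M, OC]; ldc = OC
}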
src/Linear.h

@@ -116,7 +116,7 @@ public:
     Tensor wtscale;
     Tensor wcscales;

-    hipblasHandle_t handle;
+    cublasHandle_t handle;
 };

 class GEMM_W8A8 : public Module {
src/Module.h

@@ -258,7 +258,7 @@ private:
         waitEvent(eventLoadDone.get());
         funcCompute(layer);
         nextComputeDone = std::make_unique<CUDAEventWrapper>();
-        checkCUDA(hipEventRecord(nextComputeDone->event, getCurrentHIPStreamMasqueradingAsCUDA()));
+        checkCUDA(cudaEventRecord(nextComputeDone->event, getCurrentCUDAStream()));
         workaroundFlush();
     }

@@ -272,7 +272,7 @@ private:
             funcLoad(layer + 1);
         }
         nextLoadDone = std::make_unique<CUDAEventWrapper>();
-        checkCUDA(hipEventRecord(nextLoadDone->event, getCurrentHIPStreamMasqueradingAsCUDA()));
+        checkCUDA(cudaEventRecord(nextLoadDone->event, getCurrentCUDAStream()));
         workaroundFlush();
     }

@@ -287,7 +287,7 @@ private:
         if (!event) {
             return;
         }
-        checkCUDA(hipStreamWaitEvent(getCurrentHIPStreamMasqueradingAsCUDA(), event->event));
+        checkCUDA(cudaStreamWaitEvent(getCurrentCUDAStream(), event->event));
     }

     // WDDM prevents multiple streams run concurrently

@@ -312,12 +312,12 @@ private:
         if (!needWorkaround) {
             return;
         }
-        hipStreamQuery(getCurrentHIPStreamMasqueradingAsCUDA());
+        cudaStreamQuery(getCurrentCUDAStream());
     }
     void workaroundSynchronize() {
         if (!needWorkaround) {
             return;
         }
-        checkCUDA(hipEventSynchronize(eventComputeDone->event));
+        checkCUDA(cudaEventSynchronize(eventComputeDone->event));
     }
 };
src/SanaModel.cpp

@@ -5,7 +5,7 @@
 #include "flash_api.h"
 #include "kernels/misc_kernels.h"
-#include <nvtx3/roctracer/roctx.h>
+#include <nvtx3/nvToolsExt.h>

 using spdlog::fmt_lib::format;
 using namespace nunchaku;

@@ -241,9 +241,9 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
                                            bool pag,
                                            bool cfg) {
-    roctxRangePushA("SanaLinearTransformerBlock");
-    roctxRangePushA("chunk");
+    nvtxRangePushA("SanaLinearTransformerBlock");
+    nvtxRangePushA("chunk");

     // Tensor ones = Tensor::ones({hidden_size}, Tensor::FP16, x.device());

@@ -262,10 +262,10 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
     auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = chunked;
     // auto &&[shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] = kernels::split_mod<6>(timestep);

-    roctxRangePop();
+    nvtxRangePop();

     {
-        roctxRangePushA("LinearAttention");
+        nvtxRangePushA("LinearAttention");

         Tensor residual = hidden_states;
         Tensor norm_hidden_states = norm1.forward(hidden_states);

@@ -279,11 +279,11 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
         hidden_states = attn_output;

-        roctxRangePop();
+        nvtxRangePop();
     }

     {
-        roctxRangePushA("CrossAttention");
+        nvtxRangePushA("CrossAttention");

         debug("norm_hidden_states_cross", hidden_states);
         Tensor attn_output = cross_attn.forward(hidden_states, encoder_hidden_states, cu_seqlens_img, cu_seqlens_txt);

@@ -293,11 +293,11 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
         hidden_states = attn_output;

-        roctxRangePop();
+        nvtxRangePop();
     }

     {
-        roctxRangePushA("Feed-forward");
+        nvtxRangePushA("Feed-forward");

         debug("hidden_states_ff", hidden_states);
         Tensor norm_hidden_states = norm2.forward(hidden_states);

@@ -311,10 +311,10 @@ Tensor SanaLinearTransformerBlock::forward(Tensor hidden_states,
         hidden_states = ff_output;

-        roctxRangePop();
+        nvtxRangePop();
     }

-    roctxRangePop();
+    nvtxRangePop();

     debug("hidden_states_out", hidden_states);
src/Serialization.cpp

@@ -121,15 +121,15 @@ SafeTensors::SafeTensors(const std::string &filename) {
     auto methodPrivate = [&]() {
         this->mapped = std::make_unique<MMapImplPrivate>(filename);
-        checkCUDA(
-            hipHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), hipHostRegisterPortable));
+        checkCUDA(
+            cudaHostRegister(const_cast<char *>(this->mapped->data()), this->mapped->size(), cudaHostRegisterPortable));
         this->hostRegistered = true;
         this->memoryPinned   = true;
     };
     auto methodMio = [&]() {
         this->mapped = std::make_unique<MMapImplMio>(filename);
-        checkCUDA(hipHostRegister(const_cast<char *>(this->mapped->data()),
-                                  this->mapped->size(),
-                                  hipHostRegisterPortable | hipHostRegisterReadOnly));
+        checkCUDA(cudaHostRegister(const_cast<char *>(this->mapped->data()),
+                                   this->mapped->size(),
+                                   cudaHostRegisterPortable | cudaHostRegisterReadOnly));
         this->hostRegistered = true;
         this->memoryPinned   = true;
     };

@@ -183,8 +183,8 @@ SafeTensors::SafeTensors(const std::string &filename) {
 SafeTensors::~SafeTensors() {
     if (this->hostRegistered) {
-        if (hipHostUnregister(const_cast<char *>(this->mapped->data())) != hipSuccess) {
-            spdlog::warn("hipHostUnregister failed: {}", hipGetErrorString(hipGetLastError()));
+        if (cudaHostUnregister(const_cast<char *>(this->mapped->data())) != cudaSuccess) {
+            spdlog::warn("cudaHostUnregister failed: {}", cudaGetErrorString(cudaGetLastError()));
         }
     }
 }
src/Serialization.h

@@ -9,17 +9,17 @@ public:
         this->size        = size;
         this->device.type = Device::CPU;
         this->ptr         = ptr;
-        // auto ret = hipHostRegister(ptr, size, hipHostRegisterPortable | hipHostRegisterReadOnly);
-        // if (ret == hipSuccess) {
+        // auto ret = cudaHostRegister(ptr, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
+        // if (ret == cudaSuccess) {
         //     this->registered = true;
         // } else {
-        //     log(std::format("hipHostRegister failed at {:p} (size={}): {}", ptr, size,
-        //     hipGetErrorString(hipGetLastError()))); this->registered = false;
+        //     log(std::format("cudaHostRegister failed at {:p} (size={}): {}", ptr, size,
+        //     cudaGetErrorString(cudaGetLastError()))); this->registered = false;
         // }
     }
     virtual ~BufferMMap() {
         // if (registered) {
-        //     checkCUDA(hipHostUnregister(ptr));
+        //     checkCUDA(cudaHostUnregister(ptr));
         // }
     }
src/Tensor.h

-#include "hip/hip_runtime.h"
 #pragma once

 #include "common.h"

@@ -75,10 +74,10 @@ public:
     BufferHost(size_t size) {
         this->size        = size;
         this->device.type = Device::CPU;
-        checkCUDA(hipHostMalloc(&this->ptr, size, hipHostMallocPortable));
+        checkCUDA(cudaHostAlloc(&this->ptr, size, cudaHostAllocPortable));
     }
     virtual ~BufferHost() {
-        checkCUDA(hipHostFree(this->ptr));
+        checkCUDA(cudaFreeHost(this->ptr));
     }
 };

@@ -87,20 +86,20 @@ public:
     BufferCUDA(size_t size) {
         this->size        = size;
         this->device.type = Device::CUDA;
-        // checkCUDA(hipGetDevice(&this->device.idx));
+        // checkCUDA(cudaGetDevice(&this->device.idx));
         this->device.idx = CUDADeviceContext::getDevice();
         if (size == 0) {
             this->ptr = nullptr;
         }
         // TODO: buffer used in multiple streams?
-        checkCUDA(hipMallocAsync(&this->ptr, size, getCurrentHIPStreamMasqueradingAsCUDA()));
+        checkCUDA(cudaMallocAsync(&this->ptr, size, getCurrentCUDAStream()));
     }
     virtual ~BufferCUDA() {
         if (this->size == 0) {
             assert(!this->ptr);
             return;
         }
-        checkCUDA(hipFreeAsync(this->ptr, getCurrentHIPStreamMasqueradingAsCUDA()));
+        checkCUDA(cudaFreeAsync(this->ptr, getCurrentCUDAStream()));
     }
     virtual bool isAsyncBuffer() override {
         return true;

@@ -112,11 +111,11 @@ public:
     BufferCUDASync(size_t size) {
         this->size        = size;
         this->device.type = Device::CUDA;
-        checkCUDA(hipGetDevice(&this->device.idx));
-        checkCUDA(hipMalloc(&this->ptr, size));
+        checkCUDA(cudaGetDevice(&this->device.idx));
+        checkCUDA(cudaMalloc(&this->ptr, size));
     }
     virtual ~BufferCUDASync() {
-        checkCUDA(hipFree(this->ptr));
+        checkCUDA(cudaFree(this->ptr));
     }
 };

@@ -416,8 +415,8 @@ public:
     Tensor &zero_() {
         assert(this->is_contiguous());
-        checkCUDA(hipMemsetAsync(
-            data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentHIPStreamMasqueradingAsCUDA()));
+        checkCUDA(cudaMemsetAsync(
+            data_ptr<char>() + shape.offset * scalar_size(), 0, shape.size() * scalar_size(), getCurrentCUDAStream()));
         return *this;
     }
     Tensor &copy_(Tensor other) {

@@ -445,13 +444,13 @@ public:
             return *this;
         }
-        lockBuffer(this->buffer, getCurrentHIPStreamMasqueradingAsCUDA());
-        lockBuffer(other.buffer, getCurrentHIPStreamMasqueradingAsCUDA());
-        checkCUDA(hipMemcpyAsync(data_ptr<char>(),
-                                 other.data_ptr<char>(),
-                                 shape.size() * scalar_size(),
-                                 getCopyKind(this->device(), other.device()),
-                                 getCurrentHIPStreamMasqueradingAsCUDA()));
+        lockBuffer(this->buffer, getCurrentCUDAStream());
+        lockBuffer(other.buffer, getCurrentCUDAStream());
+        checkCUDA(cudaMemcpyAsync(data_ptr<char>(),
+                                  other.data_ptr<char>(),
+                                  shape.size() * scalar_size(),
+                                  getCopyKind(this->device(), other.device()),
+                                  getCurrentCUDAStream()));
         return *this;
     }

@@ -488,7 +487,7 @@ public:
         } else if (device.type == Device::CUDA) {
             CUDADeviceContext ctx(device.idx);
-            checkCUDA(
-                hipMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentHIPStreamMasqueradingAsCUDA()));
+            checkCUDA(
+                cudaMemsetAsync(result.buffer->getPtr(), 0xCC, result.buffer->getSize(), getCurrentCUDAStream()));
         }
     }

@@ -503,7 +502,7 @@ public:
     static Tensor ones(TensorShape shape, ScalarType scalarType, Device device) {
         Tensor result = allocate(shape, scalarType, device);
         // FIXME FIXME FIXME
-        checkCUDA(hipMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentHIPStreamMasqueradingAsCUDA()));
+        checkCUDA(cudaMemsetAsync(result.buffer->getPtr(), 1, result.buffer->getSize(), getCurrentCUDAStream()));
         return result;
     }
     static Tensor

@@ -523,18 +522,18 @@ public:
         Tensor result = allocate(this->shape.dataExtent, this->scalarType, device);
         result.copy_(*this);
-        // lockBuffer(this->buffer, getCurrentHIPStreamMasqueradingAsCUDA());
-        // lockBuffer(result.buffer, getCurrentHIPStreamMasqueradingAsCUDA());
-        // checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), hipMemcpyDefault,
-        // getCurrentHIPStreamMasqueradingAsCUDA())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
-        //     checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
-        //     hipMemcpyHostToDevice, getCurrentHIPStreamMasqueradingAsCUDA()));
+        // lockBuffer(this->buffer, getCurrentCUDAStream());
+        // lockBuffer(result.buffer, getCurrentCUDAStream());
+        // checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(), cudaMemcpyDefault,
+        // getCurrentCUDAStream())); if (this->device().type == Device::CPU && device.type == Device::CUDA) {
+        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
+        //     cudaMemcpyHostToDevice, getCurrentCUDAStream()));
         // } else if (this->device().type == Device::CUDA && device.type == Device::CPU) {
-        //     checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
-        //     hipMemcpyDeviceToHost, getCurrentHIPStreamMasqueradingAsCUDA()));
+        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
+        //     cudaMemcpyDeviceToHost, getCurrentCUDAStream()));
         // } else {
-        //     checkCUDA(hipMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
-        //     hipMemcpyDefault, getCurrentHIPStreamMasqueradingAsCUDA()));
+        //     checkCUDA(cudaMemcpyAsync(result.data_ptr(), this->data_ptr(), result.buffer->getSize(),
+        //     cudaMemcpyDefault, getCurrentCUDAStream()));
         // }
         return result;
     }

@@ -549,38 +548,38 @@ public:
     // auto shapeOut = this->shape;
     // shapeOut[dim] = upper_bound - lower_bound;
     // assert(dst.shape.data == shapeOut.data);
-    // checkCUDA(hipMemcpy2DAsync(
+    // checkCUDA(cudaMemcpy2DAsync(
    //     dst.
    // ));
    // }

 private:
-    static hipMemcpyKind getCopyKind(Device dst, Device src) {
+    static cudaMemcpyKind getCopyKind(Device dst, Device src) {
         if (src.type == Device::CPU && dst.type == Device::CUDA) {
-            return hipMemcpyHostToDevice;
+            return cudaMemcpyHostToDevice;
         }
         if (src.type == Device::CUDA && dst.type == Device::CPU) {
-            return hipMemcpyDeviceToHost;
+            return cudaMemcpyDeviceToHost;
         }
         if (src.type == Device::CUDA && dst.type == Device::CUDA) {
-            return hipMemcpyDeviceToDevice;
+            return cudaMemcpyDeviceToDevice;
         }
         if (src.type == Device::CPU && dst.type == Device::CPU) {
-            return hipMemcpyHostToHost;
+            return cudaMemcpyHostToHost;
         }
-        return hipMemcpyDefault;
+        return cudaMemcpyDefault;
     }

     // static bool isAsyncBuffer(Buffer *buffer) {
     //     return dynamic_cast<BufferCUDA *>(buffer);
     // }

-    static inline std::map<hipStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;
+    static inline std::map<cudaStream_t, std::set<std::shared_ptr<Buffer>>> lockedBuffers;

 public:
     // before launching an async operation, make sure to lock the buffer in case the buffer is freed before GPU
     // completes
-    static void lockBuffer(std::shared_ptr<Buffer> buffer, hipStream_t stream) {
+    static void lockBuffer(std::shared_ptr<Buffer> buffer, cudaStream_t stream) {
         if (!buffer->isAsyncBuffer()) {
             lockedBuffers[stream].insert(buffer);
         }

@@ -590,16 +589,16 @@ public:
     static void unlockBuffers() {
         lockedBuffers.clear();
     }
-    static void unlockBuffers(hipStream_t stream) {
+    static void unlockBuffers(cudaStream_t stream) {
         lockedBuffers[stream].clear();
     }
     static void synchronizeDevice() {
-        checkCUDA(hipDeviceSynchronize());
+        checkCUDA(cudaDeviceSynchronize());
         unlockBuffers();
     }
-    static void synchronizeStream(hipStream_t stream) {
-        checkCUDA(hipStreamSynchronize(stream));
+    static void synchronizeStream(cudaStream_t stream) {
+        checkCUDA(cudaStreamSynchronize(stream));
         unlockBuffers(stream);
     }
 };
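Note on BufferCUDA above: the restored cudaMallocAsync/cudaFreeAsync calls are stream-ordered, so allocation, use, and free are safe as long as they are queued on the same stream, while host code may only touch results after a synchronize; that is also why lockBuffer keeps shared_ptr buffers alive until unlockBuffers runs at a sync point. A minimal standalone sketch of the underlying CUDA pattern (illustrative only, not code from this repository, assumes CUDA 11.2+ for stream-ordered allocation):

// Hypothetical sketch of stream-ordered allocation and deferred host access.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

int main() {
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    const size_t n = 1 << 20;
    float *dev = nullptr;
    cudaMallocAsync(&dev, n * sizeof(float), stream);    // ordered on `stream`
    cudaMemsetAsync(dev, 0, n * sizeof(float), stream);  // runs after the alloc

    std::vector<float> host(n);
    cudaMemcpyAsync(host.data(), dev, n * sizeof(float), cudaMemcpyDeviceToHost, stream);
    cudaFreeAsync(dev, stream);      // safe: ordered after the copy on the same stream

    cudaStreamSynchronize(stream);   // only now is `host` valid to read on the CPU
    printf("host[0] = %f\n", host[0]);
    cudaStreamDestroy(stream);
    return 0;
}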
src/common.h

@@ -19,47 +19,47 @@
 #include <optional>
 #include <chrono>
 #include <functional>

-#include <hip/hip_runtime_api.h>
-#include <hipblas/hipblas.h>
+#include <cuda_runtime_api.h>
+#include <cublas_v2.h>

 #include <spdlog/spdlog.h>

 class CUDAError : public std::runtime_error {
 public:
-    CUDAError(hipError_t errorCode, std::source_location location)
+    CUDAError(cudaError_t errorCode, std::source_location location)
         : std::runtime_error(format(errorCode, location)), errorCode(errorCode), location(location) {}

 public:
-    const hipError_t errorCode;
+    const cudaError_t errorCode;
     const std::source_location location;

 private:
-    static std::string format(hipError_t errorCode, std::source_location location) {
+    static std::string format(cudaError_t errorCode, std::source_location location) {
         return spdlog::fmt_lib::format(
-            "CUDA error: {} (at {}:{})", hipGetErrorString(errorCode), location.file_name(), location.line());
+            "CUDA error: {} (at {}:{})", cudaGetErrorString(errorCode), location.file_name(), location.line());
     }
 };

-inline hipError_t checkCUDA(hipError_t retValue,
+inline cudaError_t checkCUDA(cudaError_t retValue,
                             const std::source_location location = std::source_location::current()) {
-    if (retValue != hipSuccess) {
-        (void)hipGetLastError();
+    if (retValue != cudaSuccess) {
+        (void)cudaGetLastError();
         throw CUDAError(retValue, location);
     }
     return retValue;
 }

-inline hipblasStatus_t checkCUBLAS(hipblasStatus_t retValue,
+inline cublasStatus_t checkCUBLAS(cublasStatus_t retValue,
                                    const std::source_location location = std::source_location::current()) {
-    if (retValue != HIPBLAS_STATUS_SUCCESS) {
+    if (retValue != CUBLAS_STATUS_SUCCESS) {
         throw std::runtime_error(spdlog::fmt_lib::format(
-            "CUBLAS error: {} (at {}:{})", rocblas_status_to_string(retValue), location.file_name(), location.line()));
+            "CUBLAS error: {} (at {}:{})", cublasGetStatusString(retValue), location.file_name(), location.line()));
     }
     return retValue;
 }

-inline thread_local std::stack<hipStream_t> stackCUDAStreams;
+inline thread_local std::stack<cudaStream_t> stackCUDAStreams;

-inline hipStream_t getCurrentHIPStreamMasqueradingAsCUDA() {
+inline cudaStream_t getCurrentCUDAStream() {
     if (stackCUDAStreams.empty()) {
         return 0;
     }

@@ -67,9 +67,9 @@ inline hipStream_t getCurrentHIPStreamMasqueradingAsCUDA() {
 }

 struct CUDAStreamContext {
-    hipStream_t stream;
+    cudaStream_t stream;

-    CUDAStreamContext(hipStream_t stream) : stream(stream) {
+    CUDAStreamContext(cudaStream_t stream) : stream(stream) {
         stackCUDAStreams.push(stream);
     }
     CUDAStreamContext(const CUDAStreamContext &) = delete;

@@ -82,30 +82,30 @@ struct CUDAStreamContext {
 };

 struct CUDAStreamWrapper {
-    hipStream_t stream;
+    cudaStream_t stream;

     CUDAStreamWrapper() {
-        checkCUDA(hipStreamCreate(&stream));
+        checkCUDA(cudaStreamCreate(&stream));
     }
     CUDAStreamWrapper(const CUDAStreamWrapper &) = delete;
     CUDAStreamWrapper(CUDAStreamWrapper &&)      = delete;
     ~CUDAStreamWrapper() {
-        checkCUDA(hipStreamDestroy(stream));
+        checkCUDA(cudaStreamDestroy(stream));
     }
 };

 struct CUDAEventWrapper {
-    hipEvent_t event;
+    cudaEvent_t event;

-    CUDAEventWrapper(unsigned int flags = hipEventDefault) {
-        checkCUDA(hipEventCreateWithFlags(&event, flags));
+    CUDAEventWrapper(unsigned int flags = cudaEventDefault) {
+        checkCUDA(cudaEventCreateWithFlags(&event, flags));
     }
     CUDAEventWrapper(const CUDAEventWrapper &) = delete;
     CUDAEventWrapper(CUDAEventWrapper &&)      = delete;
     ~CUDAEventWrapper() {
-        checkCUDA(hipEventDestroy(event));
+        checkCUDA(cudaEventDestroy(event));
     }
 };

@@ -162,7 +162,7 @@ public:
     static int getDevice() {
         int idx = -1;
         if (cacheDisabled() || currentDeviceCache < 0) {
-            checkCUDA(hipGetDevice(&idx));
+            checkCUDA(cudaGetDevice(&idx));
         } else {
             idx = currentDeviceCache;
         }

@@ -177,7 +177,7 @@ private:
         if (!cacheDisabled() && currentDeviceCache == idx) {
             return;
         }
-        checkCUDA(hipSetDevice(idx));
+        checkCUDA(cudaSetDevice(idx));
         currentDeviceCache = cacheDisabled() ? -1 : idx;
     }

@@ -190,13 +190,13 @@ private:
     }
 };

-inline hipDeviceProp_t *getCurrentDeviceProperties() {
-    static thread_local std::map<int, hipDeviceProp_t> props;
+inline cudaDeviceProp *getCurrentDeviceProperties() {
+    static thread_local std::map<int, cudaDeviceProp> props;
     int deviceId = CUDADeviceContext::getDevice();
     if (!props.contains(deviceId)) {
-        hipDeviceProp_t prop;
-        checkCUDA(hipGetDeviceProperties(&prop, deviceId));
+        cudaDeviceProp prop;
+        checkCUDA(cudaGetDeviceProperties(&prop, deviceId));
         props[deviceId] = prop;
     }
     return &props.at(deviceId);

@@ -217,16 +217,16 @@ constexpr int log2Up(T value) {
 }

 struct CUBLASWrapper {
-    hipblasHandle_t handle = nullptr;
+    cublasHandle_t handle = nullptr;

     CUBLASWrapper() {
-        checkCUBLAS(hipblasCreate(&handle));
+        checkCUBLAS(cublasCreate(&handle));
     }
     CUBLASWrapper(CUBLASWrapper &&)       = delete;
     CUBLASWrapper(const CUBLASWrapper &&) = delete;
     ~CUBLASWrapper() {
         if (handle) {
-            checkCUBLAS(hipblasDestroy(handle));
+            checkCUBLAS(cublasDestroy(handle));
         }
     }
 };
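The checkCUDA and checkCUBLAS wrappers restored above turn raw status codes into exceptions carrying file and line from std::source_location. A short illustrative usage sketch (the allocation and handle below are hypothetical examples, not code from this repository):

// Hypothetical usage sketch of the wrappers defined in src/common.h.
#include "common.h"

void example() {
    // checkCUDA returns the cudaError_t unchanged and throws CUDAError
    // on anything other than cudaSuccess.
    void *buf = nullptr;
    checkCUDA(cudaMalloc(&buf, 1024));

    // checkCUBLAS does the same for cuBLAS status codes, formatting the
    // message with cublasGetStatusString.
    cublasHandle_t handle;
    checkCUBLAS(cublasCreate(&handle));

    checkCUBLAS(cublasDestroy(handle));
    checkCUDA(cudaFree(buf));
}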
src/interop/torch.cpp

 #include "torch.h"
-#include <ATen/hip/HIPContext.h>
+#include <ATen/cuda/CUDAContext.h>

 using spdlog::fmt_lib::format;

@@ -37,7 +37,7 @@ Tensor from_torch(at::Tensor input) {
     result.scalarType = mapType.at(input.scalar_type());
     result.buffer     = std::make_shared<BufferTorchTensor>(std::move(input));
-    Tensor::lockBuffer(result.buffer, getCurrentHIPStreamMasqueradingAsCUDA());
+    Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());
     return result;
 }

@@ -76,10 +76,10 @@ at::Tensor to_torch(Tensor input) {
 }

 TorchOpContext::TorchOpContext() {
-    stackCUDAStreams.push(at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream());
+    stackCUDAStreams.push(at::cuda::getCurrentCUDAStream().stream());
 }
 TorchOpContext::~TorchOpContext() {
-    assert(stackCUDAStreams.top() == at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream());
+    assert(stackCUDAStreams.top() == at::cuda::getCurrentCUDAStream().stream());
     stackCUDAStreams.pop();
 }
src/kernels/activation_kernels.hip → src/kernels/activation_kernels.cu

-#include "hip/hip_runtime.h"
 #include "activation_kernels_impl.cuh"
 #include "activation_kernels.h"
 #include "dispatch_utils.h"

@@ -9,10 +8,10 @@
     int num_tokens = input.numel() / d;                                                              \
     dim3 grid(num_tokens);                                                                           \
     dim3 block(std::min(d, 1024));                                                                   \
-    const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();                              \
+    const cudaStream_t stream = getCurrentCUDAStream();                                              \
     VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] {                     \
-        hipLaunchKernelGGL((vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>)                     \
-            , dim3(grid), dim3(block), 0, stream,                                                    \
-            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);                                \
+        vllm::activation_kernel<scalar_t, KERNEL<scalar_t>>                                          \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);   \
     });

 void silu_and_mul(Tensor &out,   // [..., d]

@@ -22,14 +21,14 @@ void silu_and_mul(Tensor &out, // [..., d]
     int d = input.size(-1) / 2;
     dim3 grid(num_tokens);
     dim3 block(std::min(d, 1024));
-    const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();
+    const cudaStream_t stream = getCurrentCUDAStream();
     // dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
     //     vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
     //         out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
     // });
     VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_kernel", [&] {
-        hipLaunchKernelGGL((vllm::silu_and_mul_kernel<scalar_t>), dim3(grid), dim3(block), 0, stream,
-            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
+        vllm::silu_and_mul_kernel<scalar_t>
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d);
     });
 }

@@ -42,8 +41,8 @@ void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
     int d = input.size(-1) / 2;
     dim3 grid(num_tokens);
     dim3 block(std::min(d, 1024));
-    const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();
-    hipLaunchKernelGGL((vllm::dequant_silu_and_mul_quant_kernel<float, false>), dim3(grid), dim3(block), 0, stream,
-        out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate, scale_up, scale_out);
+    const cudaStream_t stream = getCurrentCUDAStream();
+    vllm::dequant_silu_and_mul_quant_kernel<float, false><<<grid, block, 0, stream>>>(
+        out.data_ptr<int8_t>(), input.data_ptr<int32_t>(), d, scale_gate, scale_up, scale_out);
 }

@@ -58,8 +57,8 @@ void invoke_dequant_silu_and_mul_quant(Tensor &out, // [..., d]
     int d = input.size(-1) / 2;
     dim3 grid(num_tokens);
     dim3 block(std::min(d, 1024));
-    const hipStream_t stream = getCurrentHIPStreamMasqueradingAsCUDA();
-    hipLaunchKernelGGL((vllm::dequant_silu_and_mul_quant_kernel<float *, true>), dim3(grid), dim3(block), 0, stream,
-        out.data_ptr<int8_t>(),
+    const cudaStream_t stream = getCurrentCUDAStream();
+    vllm::dequant_silu_and_mul_quant_kernel<float *, true><<<grid, block, 0, stream>>>(
+        out.data_ptr<int8_t>(),
         input.data_ptr<int32_t>(),
         d,
         scale_gate,
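The launch-syntax change above is mechanical: hipify rewrites CUDA's triple-chevron launch into hipLaunchKernelGGL, and this revert restores the original form. A minimal sketch of the two equivalent launches, using a hypothetical kernel that is not part of this file:

// Hypothetical kernel used only to illustrate the two launch syntaxes.
__global__ void scale_kernel(float *data, float factor, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) data[i] *= factor;
}

void launch_scale(float *data, float factor, int n, cudaStream_t stream) {
    dim3 grid((n + 255) / 256);
    dim3 block(256);
    // CUDA form restored by this commit:
    scale_kernel<<<grid, block, 0, stream>>>(data, factor, n);
    // Equivalent HIP form that hipify would emit (shown for comparison):
    // hipLaunchKernelGGL(scale_kernel, grid, block, 0, stream, data, factor, n);
}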
src/kernels/activation_kernels_impl.cuh

-#include "hip/hip_runtime.h"
 #include "utils.cuh"
 #include "reduction_utils.cuh"
src/kernels/awq/dequantize.cuh

@@ -11,7 +11,7 @@ https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutl
 */

 #pragma once

-#include <hip/hip_fp16.h>
+#include <cuda_fp16.h>
 #include <cstdint>

 __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result) {

@@ -75,14 +75,14 @@ __forceinline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uin
     asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
 }

-__forceinline__ __device__ void dequantize_s4_to_fp16x2(__hip_bfloat162 const &source, uint4 *result) {
+__forceinline__ __device__ void dequantize_s4_to_fp16x2(__nv_bfloat162 const &source, uint4 *result) {
     // dequantize_s4_to_fp16x2(reinterpret_cast<const half2 &>(source), result);
-    // *reinterpret_cast<__hip_bfloat162 *>(&result->x) = cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2
-    // *>(&result->x)); *reinterpret_cast<__hip_bfloat162 *>(&result->y) =
-    // cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2 *>(&result->y)); *reinterpret_cast<__hip_bfloat162
-    // *>(&result->z) = cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
-    // *reinterpret_cast<__hip_bfloat162 *>(&result->w) = cuda_cast<__hip_bfloat162>(*reinterpret_cast<half2
-    // *>(&result->w));
+    // *reinterpret_cast<__nv_bfloat162 *>(&result->x) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
+    // *>(&result->x)); *reinterpret_cast<__nv_bfloat162 *>(&result->y) =
+    // cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->y)); *reinterpret_cast<__nv_bfloat162
+    // *>(&result->z) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2 *>(&result->z));
+    // *reinterpret_cast<__nv_bfloat162 *>(&result->w) = cuda_cast<__nv_bfloat162>(*reinterpret_cast<half2
+    // *>(&result->w));
     // return;
src/kernels/awq/gemm_awq.hip → src/kernels/awq/gemm_awq.cu

-#include "hip/hip_runtime.h"
-#include <hip/hip_fp16.h>
-#include <hip/hip_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
 #include "semaphore.h"
 #include "gemm_awq.h"
 // #include "../../../nunchaku/csrc/quantization/dequantize.cuh"

@@ -47,8 +46,8 @@
     dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1 * SPLITK);                                 \
     dim3 threads_per_block(WARP_SIZE, NUM_WARPS);                                                               \
     auto kernel_func = gemm_w4a16_T1<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G, SPLITK>;    \
-    hipFuncSetAttribute(kernel_func, hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);                \
-    hipLaunchKernelGGL((kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0,              \
-        in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);
+    cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);              \
+    kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(                                              \
+        in_feats, kernel, scales, zeros, out_feats, semaphores, num_in_feats, num_out_channels, num_in_channels);

 template<int N>

@@ -91,8 +90,8 @@ __inline__ __device__ uint32_t cast_smem_ptr_to_uint(void const *const ptr) {
 template<typename f16_t>
 __inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
-    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __hip_bfloat16>::value,
-                  "ldmatrix_m8n8_x4_b16 supports only half or __hip_bfloat16 types.");
+    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
+                  "ldmatrix_m8n8_x4_b16 supports only half or __nv_bfloat16 types.");
     asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16"
                  "{%0, %1, %2, %3}, [%4];"
                  : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),

@@ -104,8 +103,8 @@ __inline__ __device__ void ldmatrix_m8n8_x4_b16(f16_t *shared_warp, int ax0_0, u
 template<typename f16_t>
 __inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(f16_t *shared_warp, int ax0_0, uint32_t addr) {
-    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __hip_bfloat16>::value,
-                  "ldmatrix_m8n8_x4_trans_b16 supports only half or __hip_bfloat16 types.");
+    static_assert(std::is_same<f16_t, half>::value || std::is_same<f16_t, __nv_bfloat16>::value,
+                  "ldmatrix_m8n8_x4_trans_b16 supports only half or __nv_bfloat16 types.");
     asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
                  "{%0, %1, %2, %3}, [%4];"
                  : "=r"(((unsigned *)(shared_warp + (ax0_0 * 8)))[0]),

@@ -150,7 +149,7 @@ __device__ __inline__ void mma_m16n8k16<half>(float *C_warp, half *A_shared_warp
 template<>
 __device__ __inline__ void
-mma_m16n8k16<__hip_bfloat16>(float *C_warp, __hip_bfloat16 *A_shared_warp, __hip_bfloat16 *B_shared_warp) {
+mma_m16n8k16<__nv_bfloat16>(float *C_warp, __nv_bfloat16 *A_shared_warp, __nv_bfloat16 *B_shared_warp) {
     asm volatile(
         "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32"
         "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"

@@ -379,7 +378,7 @@ __global__ void gemm_w4a16_T1(f16_t *__restrict__ A,
                               int M,
                               int N,
                               int K) {
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
     trap_unsupported_arch();
     return;
 #endif

@@ -945,7 +944,7 @@ __global__ void gemm_w4a16_T2(f16_t *__restrict__ A,
                               int M,
                               int N,
                               int K) {
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
     trap_unsupported_arch();
     return;
 #endif

@@ -1278,12 +1277,12 @@ Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, T
             dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
             dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
             auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
-            hipFuncSetAttribute(kernel_func, hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
-            hipLaunchKernelGGL((kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0,
-                in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
+            cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
+            kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
+                in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
         }
     } else if (_in_feats.scalar_type() == Tensor::BF16) {
-        using f16_t = __hip_bfloat16;
+        using f16_t = __nv_bfloat16;
         auto in_feats = reinterpret_cast<f16_t *>(_in_feats.data_ptr());
         auto kernel   = reinterpret_cast<f16_t *>(_kernel.data_ptr<int16_t>());

@@ -1358,8 +1357,8 @@ Tensor awq_gemm_forward_cuda(Tensor _in_feats, Tensor _kernel, Tensor _scales, T
             dim3 num_blocks((num_out_feats + CTA_M - 1) / CTA_M * j_factors1);
             dim3 threads_per_block(WARP_SIZE, NUM_WARPS);
             auto kernel_func = gemm_w4a16_T2<f16_t, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>;
-            hipFuncSetAttribute(kernel_func, hipFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
-            hipLaunchKernelGGL((kernel_func), dim3(num_blocks), dim3(threads_per_block), kSmemByteSize, 0,
-                in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
+            cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize);
+            kernel_func<<<num_blocks, threads_per_block, kSmemByteSize>>>(
+                in_feats, kernel, scales, zeros, out_feats, num_in_feats, num_out_channels, num_in_channels);
         }
     } else {
src/kernels/awq/gemv_awq.hip → src/kernels/awq/gemv_awq.cu

-#include "hip/hip_runtime.h"
 /*
  * Modified from NVIDIA
  * [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)

@@ -31,8 +30,8 @@
 #include "../utils.cuh"

-#include <hip/hip_fp16.h>
-#include <hip/hip_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
 #include <stdio.h>

 #include "dequantize.cuh"

@@ -81,7 +80,7 @@ __device__ __forceinline__ packed_as<half, 2>::type half2half2<half>(half x) {
 }

 template<>
-__device__ __forceinline__ packed_as<__hip_bfloat16, 2>::type half2half2<__hip_bfloat16>(__hip_bfloat16 x) {
+__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
     return __bfloat162bfloat162(x);
 }

@@ -94,7 +93,7 @@ __device__ __forceinline__ float2 half22float2<half2>(half2 val) {
 }

 template<>
-__device__ __forceinline__ float2 half22float2<__hip_bfloat162>(__hip_bfloat162 val) {
+__device__ __forceinline__ float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
     return __bfloat1622float2(val);
 }

@@ -107,8 +106,8 @@ __global__ void gemv_kernel(const half_t *inputs,
                             const int IC,
                             const int OC) {
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
-    if constexpr (std::is_same_v<half_t, __hip_bfloat16>) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
         trap_unsupported_arch();
         return;
     }

@@ -283,10 +282,10 @@ Tensor gemv_awq(
             return;
         }
         if constexpr (M > 0) {
-            hipLaunchKernelGGL((gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>),
-                dim3(num_blocks), dim3(num_threads), 0, getCurrentHIPStreamMasqueradingAsCUDA(),
-                in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
-            checkCUDA(hipGetLastError());
+            gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>
+                <<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
+                    in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
+            checkCUDA(cudaGetLastError());
         }
     });
src/kernels/awq/semaphore.h

-#include "hip/hip_runtime.h"
 /***************************************************************************************************
  * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: BSD-3-Clause

@@ -56,7 +55,7 @@ public:
     /// Permit fetching the synchronization mechanism early
     __device__ void fetch() {
         if (wait_thread) {
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 700
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
             asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
 #else
             asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));

@@ -83,7 +82,7 @@ public:
         __syncthreads();

         if (wait_thread) {
-#if defined(__DTK_ARCH__) && __DTK_ARCH__ >= 700
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
             asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
 #else
             asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
src/kernels/dispatch_utils.h

@@ -2,13 +2,13 @@
 #include "common.h"
 #include "Tensor.h"

-#include <hip/hip_fp16.h>
+#include <cuda_fp16.h>

 template<typename F>
 inline auto dispatchFloat(Tensor::ScalarType scalarType, F &&func) {
     switch (scalarType) {
     case Tensor::BF16:
-        return func.template operator()<__hip_bfloat16>();
+        return func.template operator()<__nv_bfloat16>();
     case Tensor::FP16:
         return func.template operator()<half>();
     case Tensor::FP32:

@@ -23,7 +23,7 @@ template<typename F>
 inline auto dispatchFloat16(Tensor::ScalarType scalarType, F &&func) {
     switch (scalarType) {
     case Tensor::BF16:
-        return func.template operator()<__hip_bfloat16>();
+        return func.template operator()<__nv_bfloat16>();
     case Tensor::FP16:
         return func.template operator()<half>();
     default:

@@ -36,7 +36,7 @@ template<typename F>
 inline auto dispatch(Tensor::ScalarType scalarType, F &&func) {
     switch (scalarType) {
     case Tensor::BF16:
-        return func.template operator()<__hip_bfloat16>();
+        return func.template operator()<__nv_bfloat16>();
     case Tensor::FP16:
         return func.template operator()<half>();
     case Tensor::FP32:
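The dispatch helpers above map a runtime ScalarType onto a compile-time scalar_t by invoking a templated callable. A hedged usage sketch, mirroring the commented-out call in activation_kernels.cu; the kernel name and tensors below are hypothetical, not part of this header:

// Hypothetical caller of dispatchFloat (requires C++20 templated lambdas,
// which the codebase already uses via std::source_location).
#include "dispatch_utils.h"

void run_elementwise(Tensor &out, Tensor &input) {
    dispatchFloat(input.scalar_type(), [&]<typename scalar_t>() {
        // scalar_t is __nv_bfloat16, half, or float depending on input's dtype.
        my_elementwise_kernel<scalar_t><<<1, 256>>>(
            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), (int)input.numel());
    });
}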