gaoqiong / composable_kernel_ROCM · Commits

Commit ccaea50e · authored Mar 08, 2024 by Jing Zhang

    merge navi31_rel

Parents: 0b914465, 10127959

Changes: 126 files in the full commit. This page shows 6 changed files, with 76 additions and 58 deletions (+76 −58).
    include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp  (+4 −7)
    include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  (+29 −24)
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp  (+5 −5)
    include/ck/utility/type_convert.hpp  (+36 −20)
    profiler/include/profiler/profile_elementwise_layernorm_impl.hpp  (+1 −1)
    test/grouped_convnd_bwd_data/CMakeLists.txt  (+1 −1)
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp  (+4 −7)  ·  View file @ ccaea50e

@@ -60,8 +60,7 @@ __global__ void
             bool input_permute,
             bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
     // clang-format off
     // ***************************************************

@@ -168,7 +167,7 @@ __global__ void
     ignore = G1;
     ignore = input_permute;
     ignore = output_permute;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
 }

 // Computes C = A * B0 * B1

@@ -595,8 +594,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {

@@ -952,8 +950,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
 #if 0
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
-           ck::get_device_name() == "gfx1102")
+        if(ck::is_navi3_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
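The pattern in this file: the per-ASIC compile-time guards (__gfx1100__ / __gfx1101__ / __gfx1102__) collapse into the family-wide __gfx11__ macro, and the runtime device-name comparisons collapse into ck::is_navi3_supported(), so newer Navi3 parts are picked up without editing every guard. A minimal sketch of what such a helper could look like, assuming it matches on the gfx11 name prefix (CK's actual definition may differ; get_device_name() is stubbed here purely for illustration):

    #include <string>

    // Hypothetical sketch -- not CK's real implementation. The idea: match the
    // "gfx11" family prefix instead of enumerating gfx1100/1101/1102 by hand.
    std::string get_device_name() { return "gfx1103"; } // stub; CK queries the runtime

    bool is_navi3_supported()
    {
        const std::string name = get_device_name();
        // true for gfx1100, gfx1101, gfx1102, gfx1103, ...
        return name.rfind("gfx11", 0) == 0;
    }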
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  (+29 −24)  ·  View file @ ccaea50e

@@ -24,10 +24,10 @@ struct BlockToCTileMap_M00_N0_M01
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

-    __host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 1)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 1)
         : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
     {
     }

@@ -51,8 +51,8 @@ struct BlockToCTileMap_M00_N0_M01
     template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                             const CTileDim& c_tile_dim) const
+    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                                       const CTileDim& c_tile_dim) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);

@@ -60,7 +60,7 @@ struct BlockToCTileMap_M00_N0_M01
         return true;
     }

-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return true; // validity check moved to kernel

@@ -120,18 +120,19 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
-    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
-    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
         : M_(M), N_(N), M01_(M01)
     {
 #if 0

@@ -142,8 +143,9 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     }

     template <typename CGridDesc_M_N>
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
         : BlockToCTileMap_M00_N0_M01Adapt(
               c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
     {

@@ -164,7 +166,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     }

     template <typename CGridDesc_M_N>
-    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
     {
         return true;
     }

@@ -237,8 +239,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     }

     template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                             const CTileDim& /* c_tile_dim */) const
+    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                                       const CTileDim& /* c_tile_dim */) const
     {
         return true; // always valid provided that user gets grid size from CalculateGridSize()
     }

@@ -616,7 +618,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
         return true; // always valid provided that user gets grid size from CalculateGridSize()
     }

-    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; }
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    {
+        return true;
+    }

     private:
     index_t M01_;

@@ -674,7 +679,7 @@ struct BlockToCTileMap_M00_N00_M01_N01
         return true;
     }

-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return true; // validity check moved to kernel

@@ -786,7 +791,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
         return true;
     }

-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return true; // validity check moved to kernel

@@ -910,7 +915,7 @@ struct OffsettedBlockToCTileMap
     }

     template <typename CGridDesc_M_N>
-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
     }

@@ -967,7 +972,7 @@ struct BlockToCTileMap_3DGrid_KSplit
     }

     template <typename CGridDesc_M_N>
-    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
     {
         return true;
     }
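All of these hunks make the same change: the tile-map constructors, special member functions, ValidCTileIndex, and CheckValidity gain constexpr. A distilled, self-contained illustration of what that buys (not CK code): a map for a fixed problem size can now be constructed and validated in a constant expression, entirely at compile time.

    // Self-contained sketch, not CK code: constexpr constructors and member
    // functions allow compile-time construction and validation of a tile map.
    struct TileMapSketch
    {
        int M_, N_, M01_;

        constexpr TileMapSketch(int M, int N, int M01 = 8) : M_(M), N_(N), M01_(M01) {}

        constexpr bool CheckValidity() const { return M_ > 0 && N_ > 0 && M01_ > 0; }
    };

    // evaluated entirely at compile time thanks to constexpr
    static_assert(TileMapSketch{1024, 1024}.CheckValidity(), "tile map must be valid");

    int main() { return 0; }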
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp  (+5 −5)  ·  View file @ ccaea50e

@@ -264,7 +264,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                const BGridDesc_N_K& b_grid_desc_n_k,
                const DsGridDesc_M_N& ds_grid_desc_m_n,
                const EGridDesc_M_N& e_grid_desc_m_n,
-               const Block2ETileMap& block_2_etile_map)
+               const Block2ETileMap&)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,

@@ -310,10 +310,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         }

         // check block-to-E-tile
-        if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
-        {
-            return false;
-        }
+        // if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
+        // {
+        //     return false;
+        // }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
         // check tensor size: cannot be larger than 2GB each
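Here the argument check stops consulting the block-to-E-tile map: the runtime CheckValidity call is commented out, and the now-unused Block2ETileMap parameter loses its name so the signature stays the same for callers while avoiding an unused-parameter warning. A tiny self-contained illustration of that idiom (not CK code):

    // Sketch only: dropping the name of a parameter that is no longer used
    // keeps the signature stable while silencing -Wunused-parameter.
    struct Map
    {
    };

    static bool CheckArgument(int m, int n, const Map& /* map no longer inspected */)
    {
        return m > 0 && n > 0;
    }

    int main()
    {
        Map map;
        return CheckArgument(64, 64, map) ? 0 : 1;
    }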
include/ck/utility/type_convert.hpp  (+36 −20)  ·  View file @ ccaea50e

@@ -164,21 +164,24 @@ __host__ __device__ constexpr Y f8_convert_sr(X x);
 template <>
 inline __host__ __device__ f8_t f8_convert_sr<f8_t, float>(float x)
 {
-    constexpr int seed = 42;
+    constexpr int seed = 1254739;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
 #if defined(__gfx94__)
-    float max_fp8 = 240.0f;
-    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
     union
     {
         float fval;
         uint32_t i32val;
         uint8_t i8val[4]; // not endian independent
     } val;
     val.fval      = x;
     uint32_t ival = 0;
-    ival       = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
-    val.i32val = ival;
+    const float max_fp8 = 240.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
+    ival       = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos
+    val.i32val = ival;
     return val.i8val[0]; // little endian
 #else
     constexpr bool negative_zero_nan = true;

@@ -201,7 +204,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
     constexpr bool negative_zero_nan = true;
     constexpr bool clip              = true;
     constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
-    constexpr int seed               = 42;
+    constexpr int seed               = 1254739;
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
     return utils::cast_to_f8<half_t, f8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(

@@ -213,7 +216,7 @@ inline __host__ __device__ f8_t f8_convert_sr<f8_t, half_t>(half_t x)
 template <>
 inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
 {
-    constexpr int seed = 42;
+    constexpr int seed = 1254739;
     uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x);
 #if defined(__gfx94__)
     union

@@ -222,10 +225,15 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, float>(float x)
         uint32_t i32val;
         uint8_t i8val[4]; // not endian independent
     } val;
     val.fval      = x;
     uint32_t ival = 0;
-    ival       = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
-    val.i32val = ival;
+    const float max_bf8 = 57344.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
+    ival       = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
+    val.i32val = ival;
     return val.i8val[0]; // little endian
 #else
     constexpr bool negative_zero_nan = true;

@@ -248,7 +256,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
     constexpr bool negative_zero_nan = true;
     constexpr bool clip              = true;
     constexpr f8_rounding_mode rm    = f8_rounding_mode::stochastic;
-    constexpr int seed               = 42;
+    constexpr int seed               = 1254739;
     uint32_t rng = prand_generator<half_t, seed>(reinterpret_cast<uintptr_t>(&x), x);
     return utils::cast_to_f8<half_t, bf8_t, negative_zero_nan, clip, (rm == f8_rounding_mode::stochastic)>(

@@ -265,16 +273,19 @@ template <>
 inline __host__ __device__ f8_t f8_convert_rne<f8_t, float>(float x)
 {
 #if defined(__gfx94__)
-    float max_fp8 = 240.0f;
-    x             = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x);
     union
     {
         float fval;
         uint32_t i32val;
         uint8_t i8val[4]; // not endian independent
     } val;
     val.fval      = x;
     uint32_t ival = 0;
+    const float max_fp8 = 240.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8);
     ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0
     val.i32val = ival;
     return val.i8val[0];

@@ -318,8 +329,13 @@ inline __host__ __device__ bf8_t f8_convert_rne<bf8_t, float>(float x)
         uint32_t i32val;
         uint8_t i8val[4]; // not endian independent
     } val;
     val.fval      = x;
     uint32_t ival = 0;
+    const float max_bf8 = 57344.0f;
+    // if x is not +/- infinity or nan
+    if((val.i32val & NumericUtils<float>::nan_mask) != NumericUtils<float>::Inf)
+        // clip float value
+        val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8);
     ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
     val.i32val = ival;
     return val.i8val[0];
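Two changes repeat through this file. First, the stochastic-rounding PRNG seed moves from 42 to 1254739. Second, the clipping logic changes: instead of an unconditional ternary clamp (which would also squash +/-infinity into the finite range), the value is clipped to the largest representable magnitude (240.0f for f8, 57344.0f for bf8) with the __builtin_amdgcn_fmed3f median-of-three intrinsic, and only when the exponent-bit test shows the input is finite, so infinities and NaNs pass through to the converter unchanged. A host-side sketch of the pattern (my illustration; on gfx94 hardware the intrinsic performs the median in a single instruction):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Host-side sketch of the clamping pattern above (illustration only).
    // med3(x, max, -max) is exactly clamp(x, -max, max) for finite x.
    static float med3(float a, float b, float c)
    {
        return std::max(std::min(a, b), std::min(std::max(a, b), c));
    }

    static float clip_for_fp8(float x)
    {
        const float max_fp8 = 240.0f; // largest finite fp8 magnitude used above
        if(!std::isinf(x) && !std::isnan(x)) // mirrors the nan_mask/Inf bit test
            x = med3(x, max_fp8, -max_fp8);
        return x;
    }

    int main()
    {
        std::printf("%g %g %g\n", clip_for_fp8(1000.0f), clip_for_fp8(-1e9f), clip_for_fp8(3.5f));
        // prints: 240 -240 3.5
    }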
profiler/include/profiler/profile_elementwise_layernorm_impl.hpp  (+1 −1)  ·  View file @ ccaea50e

@@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification,
         y_dev.FromDevice(y.mData.data());
         bool pass = ck::utils::check_err(
-            y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
+            y.mData, host_y.mData, "Error: Incorrect results", 5e-3, 5e-3);

         if(do_log)
         {
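The verification tolerances for the elementwise + layernorm profiler are relaxed from 1e-3 to 5e-3 in both tolerance arguments of ck::utils::check_err. A hypothetical sketch of a combined rtol/atol check in that spirit, to show why the looser bound admits slightly larger elementwise error (the real check_err's exact semantics may differ):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Hypothetical sketch, not CK's real API: fail when the elementwise error
    // exceeds atol + rtol * |reference|.
    bool check_err(const std::vector<float>& out, const std::vector<float>& ref,
                   const char* msg, double rtol, double atol)
    {
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            const double err = std::fabs(out[i] - ref[i]);
            if(err > atol + rtol * std::fabs(ref[i]))
            {
                std::printf("%s at index %zu: |%g - %g| = %g\n", msg, i, out[i], ref[i], err);
                return false;
            }
        }
        return true;
    }

    int main()
    {
        std::vector<float> y{1.000f, 2.004f}, host_y{1.0f, 2.0f};
        // passes at the relaxed 5e-3 tolerances, fails at the old 1e-3
        std::printf("%d %d\n", check_err(y, host_y, "Error", 5e-3, 5e-3),
                               check_err(y, host_y, "Error", 1e-3, 1e-3));
    }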
test/grouped_convnd_bwd_data/CMakeLists.txt  (+1 −1)  ·  View file @ ccaea50e

 list(APPEND gpu_list_xdl gfx908 gfx90a gfx940)
-list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
+list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
     if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)