Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b2a49620
Commit
b2a49620
authored
May 16, 2023
by
carlushuang
Browse files
shrink karg for streamk
parent
fcb2911e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
92 additions
and
88 deletions
+92
-88
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
...sor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
+15
-2
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+23
-26
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
...ensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
+53
-59
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp
View file @
b2a49620
...
@@ -143,8 +143,21 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
...
@@ -143,8 +143,21 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
// TODO: remove clear buffer for streamk kernels
// TODO: remove clear buffer for streamk kernels
hipGetErrorString
(
hipMemset
(
karg
.
p_c_grid
,
0
,
karg
.
M
*
karg
.
N
*
sizeof
(
CDataType
)));
hipGetErrorString
(
hipMemset
(
karg
.
p_c_grid
,
0
,
karg
.
M
*
karg
.
N
*
sizeof
(
CDataType
)));
ave_time
=
ave_time
=
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
kernel
,
grid_dims
,
dim3
(
BlockSize
),
0
,
karg
);
kernel
,
grid_dims
,
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
karg
.
M
,
karg
.
N
,
karg
.
K
,
karg
.
StrideA
,
karg
.
StrideB
,
karg
.
StrideC
,
karg
.
block_mapping
);
return
ave_time
;
return
ave_time
;
}
}
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
b2a49620
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include <limits>
#include <limits>
#include <stdlib.h>
namespace
ck
{
namespace
ck
{
...
@@ -635,14 +636,18 @@ struct BlockToCTileMap_3DGrid_KSplit
...
@@ -635,14 +636,18 @@ struct BlockToCTileMap_3DGrid_KSplit
return
true
;
return
true
;
}
}
};
};
#include <stdlib.h>
template
<
uint32_t
MPerBlock_
,
uint32_t
NPerBlock_
,
uint32_t
KPerBlock_
>
template
<
uint32_t
MPerBlock_
,
uint32_t
NPerBlock_
,
uint32_t
KPerBlock_
,
uint32_t
TileSwizzleSubM_
=
8
>
struct
BlockToCTileMap_GemmStreamK
struct
BlockToCTileMap_GemmStreamK
{
{
static
constexpr
uint32_t
min_k_iters_per_sk_block
=
2
;
static
constexpr
uint32_t
min_k_iters_per_sk_block
=
2
;
static
constexpr
uint32_t
MPerBlock
=
MPerBlock_
;
static
constexpr
uint32_t
MPerBlock
=
MPerBlock_
;
static
constexpr
uint32_t
NPerBlock
=
NPerBlock_
;
static
constexpr
uint32_t
NPerBlock
=
NPerBlock_
;
static
constexpr
uint32_t
KPerBlock
=
KPerBlock_
;
static
constexpr
uint32_t
KPerBlock
=
KPerBlock_
;
static
constexpr
uint32_t
tile_swizzle_sub_m
=
TileSwizzleSubM_
;
//--------------------------------------
//--------------------------------------
// pass to device
// pass to device
...
@@ -657,8 +662,8 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -657,8 +662,8 @@ struct BlockToCTileMap_GemmStreamK
uint32_t
k_iters_per_big_block
;
uint32_t
k_iters_per_big_block
;
MDiv
k_iters_per_tile
;
MDiv
k_iters_per_tile
;
MDiv
n_tiles
;
MDiv
n_tiles
;
MDiv
tile_swizzle_sub_m
;
MDiv
tile_swizzle_sub_m_rem
;
//
MDiv tile_swizzle_sub_m_rem;
//--------------------------------------
//--------------------------------------
static
int
env_get_int
(
const
char
*
var_name
,
int
default_int
)
static
int
env_get_int
(
const
char
*
var_name
,
int
default_int
)
...
@@ -676,8 +681,7 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -676,8 +681,7 @@ struct BlockToCTileMap_GemmStreamK
uint32_t
k
,
uint32_t
k
,
uint32_t
num_cu
,
uint32_t
num_cu
,
uint32_t
occupancy
,
uint32_t
occupancy
,
uint32_t
sk_blocks
=
0xffffffff
,
uint32_t
sk_blocks
=
0xffffffff
)
uint32_t
tile_swizzle_sub_m_factor
=
8
)
{
{
uint32_t
num_tiles
=
uint32_t
num_tiles
=
math
::
integer_divide_ceil
(
m
,
MPerBlock
)
*
math
::
integer_divide_ceil
(
n
,
NPerBlock
);
math
::
integer_divide_ceil
(
m
,
MPerBlock
)
*
math
::
integer_divide_ceil
(
n
,
NPerBlock
);
...
@@ -723,15 +727,8 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -723,15 +727,8 @@ struct BlockToCTileMap_GemmStreamK
sk_tiles
=
partial_dispatche_tiles
+
num_cu
;
sk_tiles
=
partial_dispatche_tiles
+
num_cu
;
}
}
// dp_num_blocks = dp_tiles;
// dp_start_block_idx = num_cu * sk_occupancy;
dp_iters_per_block
=
k_iters_per_tile
.
get
();
dp_iters_per_block
=
k_iters_per_tile
.
get
();
sk_total_iters
=
k_iters_per_tile
.
get
()
*
sk_tiles
;
sk_total_iters
=
k_iters_per_tile
.
get
()
*
sk_tiles
;
// printf("num_tiles:%d, full_dispatches:%d, full_dispatch_tiles:%d,
// partial_dispatche_tiles:%d\n",
// num_tiles, full_dispatches, full_dispatch_tiles, partial_dispatche_tiles);
{
{
uint32_t
min_sk_tiles
=
(
sk_tiles
>=
num_cu
)
?
num_cu
:
(
sk_tiles
+
1
);
uint32_t
min_sk_tiles
=
(
sk_tiles
>=
num_cu
)
?
num_cu
:
(
sk_tiles
+
1
);
...
@@ -812,11 +809,8 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -812,11 +809,8 @@ struct BlockToCTileMap_GemmStreamK
}
}
n_tiles
=
MDiv
(
math
::
integer_divide_ceil
(
n
,
NPerBlock
));
n_tiles
=
MDiv
(
math
::
integer_divide_ceil
(
n
,
NPerBlock
));
tile_swizzle_sub_m_factor
=
// tile_swizzle_sub_m_rem =
env_get_int
(
"tile_swizzle_sub_m_factor"
,
tile_swizzle_sub_m_factor
);
// MDiv(math::integer_divide_ceil(m, MPerBlock) % tile_swizzle_sub_m);
tile_swizzle_sub_m
=
MDiv
(
tile_swizzle_sub_m_factor
);
tile_swizzle_sub_m_rem
=
MDiv
(
math
::
integer_divide_ceil
(
m
,
MPerBlock
)
%
tile_swizzle_sub_m_factor
);
printf
(
"cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
printf
(
"cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
"sk_num_blocks:%d, "
"sk_num_blocks:%d, "
...
@@ -896,22 +890,25 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -896,22 +890,25 @@ struct BlockToCTileMap_GemmStreamK
// swizzle tile
// swizzle tile
uint32_t
m_tiles
=
math
::
integer_divide_ceil
(
m
,
MPerBlock
);
uint32_t
m_tiles
=
math
::
integer_divide_ceil
(
m
,
MPerBlock
);
// uint32_t n_tiles = math::integer_divide_ceil(n, NPerBlock);
uint32_t
quo_sub_m
,
rem_sub_m
;
uint32_t
tile_swizzle_sub_m_rem
=
m_tiles
%
tile_swizzle_sub_m
;
tile_swizzle_sub_m
.
divmod
(
m_tile_idx
,
quo_sub_m
,
rem_sub_m
);
const
auto
sub_m_adapt
=
(
m_tile_idx
<
(
m_tiles
-
tile_swizzle_sub_m_rem
.
get
()
))
const
auto
sub_m_adapt
=
(
m_tile_idx
<
(
m_tiles
-
tile_swizzle_sub_m_rem
))
?
tile_swizzle_sub_m
?
tile_swizzle_sub_m
:
tile_swizzle_sub_m_rem
;
:
tile_swizzle_sub_m_rem
;
uint32_t
m_tile_idx_sub0
,
m_tile_idx_sub1
;
uint32_t
m_tile_idx_sub0
,
m_tile_idx_sub1
;
tile_swizzle_sub_m
.
divmod
(
m_tile_idx
,
m_tile_idx_sub0
,
m_tile_idx_sub1
);
m_tile_idx_sub0
=
m_tile_idx
/
tile_swizzle_sub_m
;
m_tile_idx_sub1
=
m_tile_idx
%
tile_swizzle_sub_m
;
uint32_t
tile_idx_local
=
n_tile_idx
+
m_tile_idx_sub1
*
n_tiles
.
get
();
uint32_t
tile_idx_local
=
n_tile_idx
+
m_tile_idx_sub1
*
n_tiles
.
get
();
uint32_t
m_tile_idx_with_adapt
,
n_tile_idx_with_adapt
;
uint32_t
m_tile_idx_with_adapt
,
n_tile_idx_with_adapt
;
sub_m_adapt
.
divmod
(
tile_idx_local
,
n_tile_idx_with_adapt
,
m_tile_idx_with_adapt
);
return
make_tuple
(
m_tile_idx_with_adapt
+
m_tile_idx_sub0
*
tile_swizzle_sub_m
.
get
(),
n_tile_idx_with_adapt
=
tile_idx_local
/
sub_m_adapt
;
m_tile_idx_with_adapt
=
tile_idx_local
%
sub_m_adapt
;
// sub_m_adapt.divmod(tile_idx_local, n_tile_idx_with_adapt, m_tile_idx_with_adapt);
return
make_tuple
(
m_tile_idx_with_adapt
+
m_tile_idx_sub0
*
tile_swizzle_sub_m
,
n_tile_idx_with_adapt
);
n_tile_idx_with_adapt
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
View file @
b2a49620
...
@@ -24,14 +24,36 @@ __global__ void
...
@@ -24,14 +24,36 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_xdlops_streamk
(
typename
GridwiseGemm
::
Argument
karg
)
// kernel_gemm_xdlops_streamk(typename GridwiseGemm::Argument karg)
kernel_gemm_xdlops_streamk
(
const
typename
GridwiseGemm
::
FloatAB
*
p_a_grid
,
const
typename
GridwiseGemm
::
FloatAB
*
p_b_grid
,
typename
GridwiseGemm
::
FloatC
*
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
typename
GridwiseGemm
::
Block2CTileMap
block_mapping
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
__shared__
uint8_t
p_shared
[
shared_size
];
__shared__
uint8_t
p_shared
[
shared_size
];
GridwiseGemm
::
Run
(
karg
,
static_cast
<
void
*>
(
p_shared
));
// GridwiseGemm::Run(karg, static_cast<void*>(p_shared));
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_c_grid
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
block_mapping
,
static_cast
<
void
*>
(
p_shared
));
#else
#else
ignore
=
karg
;
ignore
=
karg
;
ignore
=
b2c_map
;
ignore
=
b2c_map
;
...
@@ -40,9 +62,9 @@ __global__ void
...
@@ -40,9 +62,9 @@ __global__ void
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
Block2CTileMap_
,
typename
Block2CTileMap_
,
typename
FloatAB
,
typename
FloatAB
_
,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
_
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
,
typename
CLayout
,
...
@@ -95,8 +117,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -95,8 +117,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
static
constexpr
auto
KPerBlock
=
K0PerBlock
*
K1
;
static
constexpr
auto
KPerBlock
=
K0PerBlock
*
K1
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
FloatCShuffle
=
FloatAcc
;
using
Block2CTileMap
=
Block2CTileMap_
;
using
Block2CTileMap
=
Block2CTileMap_
;
using
FloatAB
=
FloatAB_
;
using
FloatC
=
FloatC_
;
struct
Argument
:
public
ck
::
tensor_operation
::
device
::
BaseArgument
struct
Argument
:
public
ck
::
tensor_operation
::
device
::
BaseArgument
{
{
...
@@ -154,31 +179,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -154,31 +179,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
math
::
integer_divide_ceil
(
karg
.
M
,
MPerBlock
),
math
::
integer_divide_ceil
(
karg
.
M
,
MPerBlock
),
karg
.
k_batch
);
karg
.
k_batch
);
}
}
#if 0
// prefer this to be called on host
__host__ __device__ static auto CalculateMPadded(index_t M)
{
return (M + MPerBlock - 1) / MPerBlock * MPerBlock;
}
__host__ __device__ static auto CalculateNPadded(index_t N)
{
return (N + NPerBlock - 1) / NPerBlock * NPerBlock;
}
__host__ __device__ static auto CalculateK0(index_t K, index_t K_Batch = 1)
{
// k_batch * k0 * k0_per_block * k1
auto K_t = K_Batch * K0PerBlock * K1;
return (K + K_t - 1) / K_t * K0PerBlock;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K0 = CalculateK0(K, K_Batch);
return K_Batch * K0 * K1;
}
#endif
__host__
__device__
static
auto
CalculateK0
(
index_t
KPad
)
{
return
KPad
/
K1
;
}
__host__
__device__
static
auto
CalculateK0
(
index_t
KPad
)
{
return
KPad
/
K1
;
}
...
@@ -296,7 +296,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -296,7 +296,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
),
sizeof
(
FloatAB
),
c_block_size
*
sizeof
(
FloatC
));
c_block_size
*
sizeof
(
FloatC
Shuffle
));
}
}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
...
@@ -384,30 +384,29 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -384,30 +384,29 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
{}));
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
{}));
}
}
// return block_id to C matrix tile idx (m0, n0, k_split) mapping
// __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap()
// {
// return BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>();
// }
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
,
1
,
1
))
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
,
1
,
1
))
>
;
// using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap())>;
__device__
static
void
Run
(
const
Argument
&
karg
,
void
*
__restrict__
p_shared_block
)
__device__
static
void
Run
(
const
FloatAB
*
p_a_grid
,
const
FloatAB
*
p_b_grid
,
FloatC
*
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
Block2CTileMap
block_mapping
,
void
*
__restrict__
p_shared_block
)
{
{
const
FloatAB
*
p_a_grid
=
karg
.
p_a_grid
;
uint32_t
m
=
M
;
const
FloatAB
*
p_b_grid
=
karg
.
p_b_grid
;
uint32_t
n
=
N
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
uint32_t
k
=
K
;
uint32_t
m
=
karg
.
M
;
uint32_t
n
=
karg
.
N
;
uint32_t
k
=
karg
.
K
;
uint32_t
pad_m
=
(
m
+
MPerBlock
-
1
)
/
MPerBlock
*
MPerBlock
;
uint32_t
pad_m
=
(
m
+
MPerBlock
-
1
)
/
MPerBlock
*
MPerBlock
;
uint32_t
pad_n
=
(
n
+
NPerBlock
-
1
)
/
NPerBlock
*
NPerBlock
;
uint32_t
pad_n
=
(
n
+
NPerBlock
-
1
)
/
NPerBlock
*
NPerBlock
;
uint32_t
pad_k
=
(
k
+
KPerBlock
-
1
)
/
KPerBlock
*
KPerBlock
;
uint32_t
pad_k
=
(
k
+
KPerBlock
-
1
)
/
KPerBlock
*
KPerBlock
;
uint32_t
stride_a
=
karg
.
StrideA
;
uint32_t
stride_a
=
StrideA
;
uint32_t
stride_b
=
karg
.
StrideB
;
uint32_t
stride_b
=
StrideB
;
uint32_t
stride_c
=
karg
.
StrideC
;
uint32_t
stride_c
=
StrideC
;
const
auto
a_k0_m_k1_grid_desc
=
MakeAGridDescriptor_K0_M_K1
(
m
,
pad_m
,
k
,
pad_k
,
stride_a
);
const
auto
a_k0_m_k1_grid_desc
=
MakeAGridDescriptor_K0_M_K1
(
m
,
pad_m
,
k
,
pad_k
,
stride_a
);
const
auto
b_k0_n_k1_grid_desc
=
MakeBGridDescriptor_K0_N_K1
(
k
,
pad_k
,
n
,
pad_n
,
stride_b
);
const
auto
b_k0_n_k1_grid_desc
=
MakeBGridDescriptor_K0_N_K1
(
k
,
pad_k
,
n
,
pad_n
,
stride_b
);
...
@@ -467,10 +466,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -467,10 +466,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// gridwise GEMM pipeline
// gridwise GEMM pipeline
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_v3
();
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_v3
();
auto
&
block_mapping
=
karg
.
block_mapping
;
uint32_t
block_idx
=
block_mapping
.
get_block_idx
();
uint32_t
block_idx
=
block_mapping
.
get_block_idx
();
bool
is_sk_block
=
block_idx
<
block_mapping
.
sk_num_blocks
;
bool
is_sk_block
=
block_idx
<
block_mapping
.
sk_num_blocks
;
bool
is_dp_block
=
block_idx
>=
block_mapping
.
dp_start_block_idx
;
bool
is_dp_block
=
block_idx
>=
block_mapping
.
dp_start_block_idx
;
uint32_t
iter_start
,
iter_end
;
uint32_t
iter_start
,
iter_end
;
block_mapping
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
block_mapping
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
uint32_t
total_iter_length
=
iter_end
-
iter_start
;
uint32_t
total_iter_length
=
iter_end
-
iter_start
;
...
@@ -608,7 +606,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -608,7 +606,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatC
*>
(
p_shared_block
),
static_cast
<
FloatC
Shuffle
*>
(
p_shared_block
),
c_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
c_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
...
@@ -662,7 +660,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -662,7 +660,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// VGPR to LDS
// VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatAcc
,
FloatC
,
FloatC
Shuffle
,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -701,7 +699,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -701,7 +699,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
,
// BlockSliceLengths,
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatC
,
// typename SrcData,
FloatC
Shuffle
,
// typename SrcData,
FloatC
,
// typename DstData,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
...
@@ -807,10 +805,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -807,10 +805,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// make sure next loop LDS is ready for use
// make sure next loop LDS is ready for use
block_sync_lds
();
block_sync_lds
();
}
}
// if(threadIdx.x == 0)
// printf("xxx bid:%d, xx_total_iter_length:%d \n", static_cast<int>(blockIdx.x),
// xx_total_iter_length);
}
}
template
<
typename
Layout
>
template
<
typename
Layout
>
...
...
include/ck/utility/magic_division.hpp
View file @
b2a49620
...
@@ -162,7 +162,7 @@ struct MDiv
...
@@ -162,7 +162,7 @@ struct MDiv
// 1 dword -> 3 dword storage
// 1 dword -> 3 dword storage
uint32_t
divisor
;
uint32_t
divisor
;
uint32_t
multiplier
;
uint32_t
multiplier
;
uint32_t
shift
;
uint32_t
shift
;
// TODO: 8 bit is enough
// prefer construct on host
// prefer construct on host
__host__
__device__
MDiv
(
uint32_t
divisor_
)
:
divisor
(
divisor_
)
__host__
__device__
MDiv
(
uint32_t
divisor_
)
:
divisor
(
divisor_
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment