Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
86580888
Commit
86580888
authored
Jun 16, 2022
by
raman jana
Browse files
fixes for global-write for math-wave
parent
a607bc1a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
66 deletions
+40
-66
include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp
.../tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp
+27
-54
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
+13
-12
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp
View file @
86580888
...
...
@@ -11,15 +11,15 @@ struct GridwiseGemmLoadWave;
template
<
typename
TileLoadThreadGroup
>
struct
GridwiseGemmLoadWave
<
TileLoadThreadGroup
,
1
>
{
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
/*
num_loop
*/
)
{
// TODO: improve applicability
return
num_loop
%
2
==
0
;
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
num_loop
/
2
>
1
;
return
num_loop
>
1
;
}
template
<
bool
HasMainLoop
,
...
...
@@ -37,35 +37,29 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
typename
BBlockTransferStep
>
static
__device__
void
RunLoadWavePipeline
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_block_copy
,
ABlockTransfer
&
a_block
wise
_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_block_copy
,
BBlockTransfer
&
b_block
wise
_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
const
BBlockTransferStep
&
b_block_copy_step
,
index_t
num_loop
)
{
// global read 0
a_block_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_block_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_block
wise
_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_block
wise
_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
//move to 1
a_block_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
a_block
wise
_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_block
wise
_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
//LDS write 0
a_block_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// global Read 1
a_block_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
// LDS write 0
b_block_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// global Read 1
b_block_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
if
constexpr
(
HasMainLoop
)
{
...
...
@@ -75,43 +69,31 @@ struct GridwiseGemmLoadWave<TileLoadThreadGroup, 1>
{
//sync for Load threads()
block_sync_lds
();
// global read i + 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
// move to i + 2
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
//?? what is this for
// sync with math threads()
block_sync_lds
();
// move to i + 2
a_block_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// LDS write i + 1
a_block_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
// global read i + 2
a_block_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
//LDS write i+1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// LDS write i + 1
b_block_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
);
// global read i + 2
b_block_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
++
i
;
}
while
(
i
<
(
num_loop
-
2
));
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
{
block_sync_lds
();
//what is this for??
block_sync_lds
();
// move to i + 2
a_block_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_block_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
block_sync_lds
();
// GEMM num_loop
}
...
...
@@ -126,15 +108,14 @@ template <typename TileMathThreadGroup>
struct
GridwiseGemmMathWave
<
TileMathThreadGroup
,
1
>
{
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
/*
num_loop
*/
)
{
// TODO: improve applicability
return
num_loop
%
2
==
0
;
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
num_loop
/
2
>
1
;
return
num_loop
>
1
;
}
template
<
bool
HasMainLoop
,
...
...
@@ -165,24 +146,16 @@ struct GridwiseGemmMathWave<TileMathThreadGroup, 1>
block_sync_lds
();
++
i
;
}
while
(
i
<
(
num_loop
-
2
));
}
while
(
i
<
(
num_loop
-
1
));
}
// tail
{
block_sync_lds
();
// GEMM num_loop - 2
block_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds
();
// LDS write num_loop - 1
block_sync_lds
();
// GEMM num_loop - 1
block_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp
View file @
86580888
...
...
@@ -137,10 +137,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
}
__device__
static
constexpr
bool
IsBelong
()
{
return
(
get_thread_local_1d_id
()
<
TileLoadThreadGroupSize
);
return
(
get_thread_local_1d_id
()
>=
TileLoadThreadGroupSize
);
}
__device__
static
index_t
GetThreadId
()
{
return
get_thread_local_1d_id
();
}
__device__
static
index_t
GetThreadId
()
{
return
get_thread_local_1d_id
()
-
TileMathThreadGroupSize
;
}
};
...
...
@@ -152,10 +152,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
}
__device__
static
constexpr
bool
IsBelong
()
{
return
get_thread_local_1d_id
()
>=
Tile
Load
ThreadGroupSize
;
return
get_thread_local_1d_id
()
<
Tile
Math
ThreadGroupSize
;
}
__device__
static
index_t
GetThreadId
()
{
return
get_thread_local_1d_id
()
-
TileMathThreadGroupSize
;
}
__device__
static
index_t
GetThreadId
()
{
return
get_thread_local_1d_id
();
}
};
using
CShuffleBlockTransferThreadGroup
=
...
...
@@ -476,11 +476,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
b_block_buf
,
b_block_slice_copy_step
,
num_k_block_main_loop
);
block_sync_lds
();
block_sync_lds
();
}
else
if
(
TileMathThreadGroup
::
IsBelong
())
{
//branch early for math wave
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
...
...
@@ -507,7 +511,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
c_thread_buf
,
num_k_block_main_loop
);
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
...
...
@@ -691,7 +694,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
...
...
@@ -708,7 +710,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
// move on C
c_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
}
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment