Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
2dfbfbbc
Unverified
Commit
2dfbfbbc
authored
Oct 19, 2023
by
Chao Liu
Committed by
GitHub
Oct 19, 2023
Browse files
Revert "slice kv, and use 3d padding LDS layout (#15)" (#18)
This reverts commit
7b1a0b7f
.
parent
9f36ac7c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
30 additions
and
80 deletions
+30
-80
example/91_tile_program/batched_gemm_softmax_gemm.cpp
example/91_tile_program/batched_gemm_softmax_gemm.cpp
+1
-3
example/91_tile_program/batched_gemm_softmax_gemm.hpp
example/91_tile_program/batched_gemm_softmax_gemm.hpp
+2
-4
example/91_tile_program/gemm_softmax_gemm_impl.hpp
example/91_tile_program/gemm_softmax_gemm_impl.hpp
+27
-73
No files found.
example/91_tile_program/batched_gemm_softmax_gemm.cpp
View file @
2dfbfbbc
...
@@ -101,7 +101,6 @@ int main(int argc, char* argv[])
...
@@ -101,7 +101,6 @@ int main(int argc, char* argv[])
constexpr
ck
::
index_t
kN0PerBlock
=
128
;
constexpr
ck
::
index_t
kN0PerBlock
=
128
;
constexpr
ck
::
index_t
kK0PerBlock
=
32
;
constexpr
ck
::
index_t
kK0PerBlock
=
32
;
constexpr
ck
::
index_t
kN1PerBlock
=
128
;
constexpr
ck
::
index_t
kN1PerBlock
=
128
;
constexpr
ck
::
index_t
kK1PerBlock
=
32
;
constexpr
ck
::
index_t
kBlockSize
=
256
;
constexpr
ck
::
index_t
kBlockSize
=
256
;
ck
::
index_t
kGridSize
=
Batch
*
(
M0
/
kM0PerBlock
)
*
(
N1
/
kN1PerBlock
);
ck
::
index_t
kGridSize
=
Batch
*
(
M0
/
kM0PerBlock
)
*
(
N1
/
kN1PerBlock
);
...
@@ -126,8 +125,7 @@ int main(int argc, char* argv[])
...
@@ -126,8 +125,7 @@ int main(int argc, char* argv[])
kM0PerBlock
,
kM0PerBlock
,
kN0PerBlock
,
kN0PerBlock
,
kK0PerBlock
,
kK0PerBlock
,
kN1PerBlock
,
kN1PerBlock
>
{},
kK1PerBlock
>
{},
kGridSize
,
kGridSize
,
kBlockSize
,
kBlockSize
,
0
,
0
,
...
...
example/91_tile_program/batched_gemm_softmax_gemm.hpp
View file @
2dfbfbbc
...
@@ -34,8 +34,7 @@ template <typename QDataType,
...
@@ -34,8 +34,7 @@ template <typename QDataType,
ck
::
index_t
kM0PerBlock
,
ck
::
index_t
kM0PerBlock
,
ck
::
index_t
kN0PerBlock
,
ck
::
index_t
kN0PerBlock
,
ck
::
index_t
kK0PerBlock
,
ck
::
index_t
kK0PerBlock
,
ck
::
index_t
kN1PerBlock
,
ck
::
index_t
kN1PerBlock
>
ck
::
index_t
kK1PerBlock
>
struct
BatchedGemmSoftmaxGemm
struct
BatchedGemmSoftmaxGemm
{
{
__device__
void
operator
()(
const
QDataType
*
q_ptr
,
__device__
void
operator
()(
const
QDataType
*
q_ptr
,
...
@@ -90,8 +89,7 @@ struct BatchedGemmSoftmaxGemm
...
@@ -90,8 +89,7 @@ struct BatchedGemmSoftmaxGemm
kM0PerBlock
,
kM0PerBlock
,
kN0PerBlock
,
kN0PerBlock
,
kK0PerBlock
,
kK0PerBlock
,
kN1PerBlock
,
kN1PerBlock
>
{};
kK1PerBlock
>
{};
kernel_impl
(
q_ptr
+
iBatch
*
BatchStrideQ
,
kernel_impl
(
q_ptr
+
iBatch
*
BatchStrideQ
,
k_ptr
+
iBatch
*
BatchStrideK
,
k_ptr
+
iBatch
*
BatchStrideK
,
...
...
example/91_tile_program/gemm_softmax_gemm_impl.hpp
View file @
2dfbfbbc
...
@@ -11,7 +11,6 @@
...
@@ -11,7 +11,6 @@
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/slice_tile.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
...
@@ -33,8 +32,7 @@ template <typename QDataType,
...
@@ -33,8 +32,7 @@ template <typename QDataType,
ck
::
index_t
kM0PerBlock
,
ck
::
index_t
kM0PerBlock
,
ck
::
index_t
kN0PerBlock
,
ck
::
index_t
kN0PerBlock
,
ck
::
index_t
kK0PerBlock
,
ck
::
index_t
kK0PerBlock
,
ck
::
index_t
kN1PerBlock
,
ck
::
index_t
kN1PerBlock
>
ck
::
index_t
kK1PerBlock
>
struct
GemmSoftmaxGemmImpl
struct
GemmSoftmaxGemmImpl
{
{
// block gemm0 pipeline
// block gemm0 pipeline
...
@@ -54,7 +52,7 @@ struct GemmSoftmaxGemmImpl
...
@@ -54,7 +52,7 @@ struct GemmSoftmaxGemmImpl
VDataType
,
VDataType
,
OaccDataType
,
OaccDataType
,
kBlockSize
,
kBlockSize
,
ck
::
tile_program
::
TileGemmShape
<
kM0PerBlock
,
kN1PerBlock
,
k
K1
PerBlock
>>
,
ck
::
tile_program
::
TileGemmShape
<
kM0PerBlock
,
kN1PerBlock
,
k
N0
PerBlock
>>
,
ck
::
tile_program
::
block
::
BlockGemmARegBSmemCRegV1DefaultPolicy
>
;
ck
::
tile_program
::
block
::
BlockGemmARegBSmemCRegV1DefaultPolicy
>
;
#if 0
#if 0
...
@@ -71,7 +69,7 @@ struct GemmSoftmaxGemmImpl
...
@@ -71,7 +69,7 @@ struct GemmSoftmaxGemmImpl
return b_lds_desc;
return b_lds_desc;
}
}
#el
if
0
#el
se
// fake XOR
// fake XOR
__device__
static
constexpr
auto
MakeVLdsBlockDescriptor
()
__device__
static
constexpr
auto
MakeVLdsBlockDescriptor
()
{
{
...
@@ -103,34 +101,6 @@ struct GemmSoftmaxGemmImpl
...
@@ -103,34 +101,6 @@ struct GemmSoftmaxGemmImpl
return
b_lds_desc_n_k
;
return
b_lds_desc_n_k
;
}
}
#else
// 3d, with padding
__device__
static
constexpr
auto
MakeVLdsBlockDescriptor
()
{
using
namespace
ck
;
// using BDataType = B1DataType;
constexpr
index_t
kNPerBlock
=
kN1PerBlock
;
constexpr
index_t
kKPerBlock
=
kK1PerBlock
;
constexpr
index_t
kPad
=
1
;
constexpr
index_t
kK1
=
8
;
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
kKPerBlock
/
kK1
>
{},
Number
<
kNPerBlock
>
{},
Number
<
kK1
>
{}),
make_tuple
(
Number
<
(
kNPerBlock
+
kPad
)
*
kK1
>
{},
Number
<
kK1
>
{},
Number
<
1
>
{}),
Number
<
kK1
>
{},
Number
<
1
>
{});
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
b_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kNPerBlock
),
make_merge_transform
(
make_tuple
(
Number
<
kKPerBlock
/
kK1
>
{},
Number
<
kK1
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
b_lds_block_desc
;
}
#endif
#endif
__device__
static
constexpr
auto
MakeVDramTileDistribution
()
__device__
static
constexpr
auto
MakeVDramTileDistribution
()
...
@@ -141,7 +111,7 @@ struct GemmSoftmaxGemmImpl
...
@@ -141,7 +111,7 @@ struct GemmSoftmaxGemmImpl
using
BDataType
=
VDataType
;
using
BDataType
=
VDataType
;
constexpr
index_t
kNPerBlock
=
kN1PerBlock
;
constexpr
index_t
kNPerBlock
=
kN1PerBlock
;
constexpr
index_t
kKPerBlock
=
k
K1
PerBlock
;
constexpr
index_t
kKPerBlock
=
k
N0
PerBlock
;
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
@@ -211,7 +181,7 @@ struct GemmSoftmaxGemmImpl
...
@@ -211,7 +181,7 @@ struct GemmSoftmaxGemmImpl
auto
v_dram_window
=
auto
v_dram_window
=
make_tile_window
(
v_dram
,
make_tile_window
(
v_dram
,
make_tuple
(
Number
<
kN1PerBlock
>
{},
Number
<
k
K1
PerBlock
>
{}),
make_tuple
(
Number
<
kN1PerBlock
>
{},
Number
<
k
N0
PerBlock
>
{}),
{
iN1
,
0
},
{
iN1
,
0
},
MakeVDramTileDistribution
());
MakeVDramTileDistribution
());
...
@@ -221,7 +191,7 @@ struct GemmSoftmaxGemmImpl
...
@@ -221,7 +191,7 @@ struct GemmSoftmaxGemmImpl
MakeVLdsBlockDescriptor
());
MakeVLdsBlockDescriptor
());
auto
v_lds_window
=
make_tile_window
(
auto
v_lds_window
=
make_tile_window
(
v_lds
,
make_tuple
(
Number
<
kN1PerBlock
>
{},
Number
<
k
K1
PerBlock
>
{}),
{
0
,
0
});
v_lds
,
make_tuple
(
Number
<
kN1PerBlock
>
{},
Number
<
k
N0
PerBlock
>
{}),
{
0
,
0
});
// Block GEMM0 pipeline and Block GEMM1
// Block GEMM0 pipeline and Block GEMM1
constexpr
auto
gemm0_pipeline
=
BlockGemm0Pipeline
{};
constexpr
auto
gemm0_pipeline
=
BlockGemm0Pipeline
{};
...
@@ -244,10 +214,7 @@ struct GemmSoftmaxGemmImpl
...
@@ -244,10 +214,7 @@ struct GemmSoftmaxGemmImpl
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
Sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
SBlockTileType
{},
Sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm1
(
using
OaccBlockTileType
=
decltype
(
gemm1
(
PBlockTileType
{},
v_dram_window
));
get_slice_tile
(
PBlockTileType
{},
Sequence
<
0
,
0
>
{},
Sequence
<
kM0PerBlock
,
kK1PerBlock
>
{}),
v_dram_window
));
// init Oacc, M, L
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
o_acc
=
OaccBlockTileType
{};
...
@@ -272,9 +239,6 @@ struct GemmSoftmaxGemmImpl
...
@@ -272,9 +239,6 @@ struct GemmSoftmaxGemmImpl
const
auto
s
=
const
auto
s
=
tile_elementwise_in
(
type_convert
<
SMPLComputeDataType
,
SaccDataType
>
,
s_acc
);
tile_elementwise_in
(
type_convert
<
SMPLComputeDataType
,
SaccDataType
>
,
s_acc
);
// prefetch load v tile
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// m_local = rowmax(S{j})
// m_local = rowmax(S{j})
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
Sequence
<
1
>
{},
f_max
,
NumericLimits
<
SMPLComputeDataType
>::
Lowest
());
s
,
Sequence
<
1
>
{},
f_max
,
NumericLimits
<
SMPLComputeDataType
>::
Lowest
());
...
@@ -322,55 +286,45 @@ struct GemmSoftmaxGemmImpl
...
@@ -322,55 +286,45 @@ struct GemmSoftmaxGemmImpl
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// FIXME: this use different equation from FA v2 paper,
// but produce correc
t
result.
// but produce correc result.
// Is the equation wrong?
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
});
});
block_sync_lds
();
store_tile
(
v_lds_window
,
v_prefetch
);
move_tile_window
(
v_dram_window
,
{
0
,
kK1PerBlock
});
// type cast Pcompute{j} into P{j}
// type cast Pcompute{j} into P{j}
const
auto
p
=
const
auto
p
=
tile_elementwise_in
(
type_convert
<
PDataType
,
SMPLComputeDataType
>
,
p_compute
);
tile_elementwise_in
(
type_convert
<
PDataType
,
SMPLComputeDataType
>
,
p_compute
);
// Oacc{j}
// Block GEMM1: Oacc{j} += P{j} * V{j}
constexpr
index_t
k1_loops
=
kN0PerBlock
/
kK1PerBlock
;
if
constexpr
(
k1_loops
>
1
)
{
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
// load V{j}
const
auto
v
=
load_tile
(
v_dram_window
);
// load next v
const
auto
v
=
load_tile
(
v_dram_window
);
block_sync_lds
();
gemm1
(
o_acc
,
// wait for gemm0 pipeline to finish
get_slice_tile
(
p
,
Sequence
<
0
,
i_k1
*
kK1PerBlock
>
{},
Sequence
<
kM0PerBlock
,
(
i_k1
+
1
)
*
kK1PerBlock
>
{}),
v_lds_window
);
block_sync_lds
();
block_sync_lds
();
store_tile
(
v_lds_window
,
v
);
store_tile
(
v_lds_window
,
v
);
move_tile_window
(
v_dram_window
,
{
0
,
kK1PerBlock
});
});
// wait for store_tile to finish
}
// tail
{
block_sync_lds
();
block_sync_lds
();
gemm1
(
o_acc
,
get_slice_tile
(
p
,
// Oacc{j} += P{j} * V{j}
Sequence
<
0
,
(
k1_loops
-
1
)
*
kK1PerBlock
>
{},
gemm1
(
o_acc
,
p
,
v_lds_window
);
Sequence
<
kM0PerBlock
,
kN0PerBlock
>
{}),
v_lds_window
);
// wait for gemm1 to finish
block_sync_lds
();
block_sync_lds
();
}
}
// move tile windows
// move tile windows
move_tile_window
(
k_dram_window
,
{
kN0PerBlock
,
0
});
move_tile_window
(
k_dram_window
,
{
kN0PerBlock
,
0
});
move_tile_window
(
v_dram_window
,
{
0
,
kN0PerBlock
});
iN0
+=
kN0PerBlock
;
iN0
+=
kN0PerBlock
;
}
while
(
iN0
<
N0
);
}
while
(
iN0
<
N0
);
// O
acc
// O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
GetDistributedSpans
();
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
GetDistributedSpans
();
sweep_tile_span
(
o_spans
[
I0
],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
o_spans
[
I0
],
[
&
](
auto
idx0
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment