Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0144b4f4
Commit
0144b4f4
authored
May 30, 2024
by
letaoqin
Browse files
Merge branch 'develop' into jizhan/reduce_threadwise_multi_d
parents
300337cd
34f3dfdd
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
201 additions
and
71 deletions
+201
-71
include/ck_tile/host/timer.hpp
include/ck_tile/host/timer.hpp
+79
-0
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
+3
-3
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+2
-2
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
...ude/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
+55
-4
test/position_embedding/position_embedding.cpp
test/position_embedding/position_embedding.cpp
+62
-62
No files found.
include/ck_tile/host/timer.hpp
0 → 100644
View file @
0144b4f4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <chrono>
namespace
ck_tile
{
/// Measures elapsed time between two points on a HIP stream using a pair of HIP events.
///
/// Usage: call start(stream), enqueue work, call stop(stream), then read duration().
/// The events are created in the constructor and destroyed in the destructor, so the
/// object is the sole owner of both handles.
struct gpu_timer
{
    /// Creates the start/stop events. Failures are reported through HIP_CHECK_ERROR.
    CK_TILE_HOST gpu_timer()
    {
        HIP_CHECK_ERROR(hipEventCreate(&start_evt));
        HIP_CHECK_ERROR(hipEventCreate(&stop_evt));
    }

    // The destructor unconditionally destroys both events; an implicit copy would share
    // the same handles and destroy them a second time. Forbid copying.
    gpu_timer(const gpu_timer&)            = delete;
    gpu_timer& operator=(const gpu_timer&) = delete;

    /// Destroys both events.
    /// Declared noexcept(false) so that a failure reported by HIP_CHECK_ERROR may
    /// propagate out of the destructor (presumably as an exception; see the macro).
    CK_TILE_HOST ~gpu_timer() noexcept(false)
    {
        HIP_CHECK_ERROR(hipEventDestroy(start_evt));
        HIP_CHECK_ERROR(hipEventDestroy(stop_evt));
    }

    /// Synchronizes the whole device, then records the start event on stream `s`.
    CK_TILE_HOST void start(const hipStream_t& s)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
    }

    /// Records the stop event on stream `s` and blocks until that event has completed.
    CK_TILE_HOST void stop(const hipStream_t& s)
    {
        HIP_CHECK_ERROR(hipEventRecord(stop_evt, s));
        HIP_CHECK_ERROR(hipEventSynchronize(stop_evt));
    }

    /// Returns the time between the recorded start and stop events, in milliseconds.
    CK_TILE_HOST float duration() const
    {
        float ms = 0;
        HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt));
        return ms;
    }

    private:
    hipEvent_t start_evt, stop_evt;
};
/// Host wall-clock timer with the same start/stop/duration shape as gpu_timer.
///
/// Modelled on torch.utils.benchmark.Timer(): the device is synchronized inside each
/// timer callback before the host clock is sampled. The stream argument of start() and
/// stop() is accepted but not used.
struct cpu_timer
{
    /// Synchronizes the device, then samples the host clock as the start point.
    CK_TILE_HOST void start(const hipStream_t&)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        start_tick = clock_type::now();
    }

    /// Synchronizes the device, then samples the host clock as the stop point.
    CK_TILE_HOST void stop(const hipStream_t&)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        stop_tick = clock_type::now();
    }

    /// Returns the time between the start and stop samples, in milliseconds.
    CK_TILE_HOST float duration() const
    {
        // Express the interval as double-precision seconds, then scale to ms.
        const std::chrono::duration<double> elapsed_sec = stop_tick - start_tick;
        return static_cast<float>(elapsed_sec.count() * 1e3);
    }

    private:
    using clock_type = std::chrono::high_resolution_clock;

    std::chrono::time_point<clock_type> start_tick;
    std::chrono::time_point<clock_type> stop_tick;
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
View file @
0144b4f4
...
...
@@ -23,13 +23,13 @@ VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
TOP_LEFT:
TOP_LEFT
(but negative)
:
[0] 1 2 3 4 5
1 [0] 1 2 3 4
2 1 [0] 1 2 3
3 2 1 [0] 1 2
FROM_BOTTOM_RIGHT:
FROM_BOTTOM_RIGHT
(but negative)
:
2 1 [0] 1 2 3
3 2 1 [0] 1 2
4 3 2 1 [0] 1
...
...
@@ -54,7 +54,7 @@ struct Alibi
index_t
x_total_
,
AlibiMode
mode_
=
AlibiMode
::
VERTICAL
)
{
slope
=
mode_
==
AlibiMode
::
VERTICAL
?
slope_
:
-
slope
;
slope
=
mode_
==
AlibiMode
::
VERTICAL
?
slope_
:
-
slope
_
;
shift_left_up
=
[
&
]()
{
if
(
RowMajor
)
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
0144b4f4
...
...
@@ -76,7 +76,7 @@ struct FmhaFwdKernel
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_fwd_d"
)
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
_SS_
(
TilePartitioner
::
name
)
+
"_"
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK0BlockLength
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
...
...
@@ -702,7 +702,7 @@ struct FmhaFwdKernel
else
{
return
Alibi
<
SaccDataType
,
true
>
{
slope
,
kargs
.
seqlen_q
,
kargs
.
seqlen_k
,
AlibiMode
::
VERTICAL
};
slope
,
kargs
.
seqlen_q
,
kargs
.
seqlen_k
,
AlibiMode
::
FROM_BOTTOM_RIGHT
};
}
}
else
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
View file @
0144b4f4
...
...
@@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
static
constexpr
const
char
*
name
=
"shb"
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
)
*
...
...
@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner
}
};
// Alias that exposes FmhaFwdTilePartitioner<BlockFmhaShape_> under an explicit "_SHB"
// name, alongside the FmhaFwdTilePartitioner_HBS variant declared below.
template
<
typename
BlockFmhaShape_
>
using
FmhaFwdTilePartitioner_SHB
=
FmhaFwdTilePartitioner
<
BlockFmhaShape_
>
;
/// Tile partitioner whose launch grid is laid out as (nhead, batch, tile):
/// blockIdx.x selects the head, blockIdx.y the batch, and blockIdx.z a flattened
/// (tile_m, tile_n) index.
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    // Tile extents forwarded from the block shape.
    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    // Short tag identifying this partitioner.
    static constexpr const char* name = "hbs";

    /// Builds the launch grid: x = nhead_, y = batch_size_,
    /// z = ceil(seqlen_q_ / kM0) * ceil(hdim_v_ / kN1).
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_)
    {
        // TODO: this may need tuning
        return dim3(nhead_,
                    batch_size_,
                    ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v_, kN1));
    }

    /// Decodes the current block index into (i_tile_m, i_tile_n, i_nhead, i_batch).
    /// The first argument (seqlen_q) is unused; only hdim_v is needed to split the
    /// flattened z index.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t i_nhead = blockIdx.x;
        const index_t i_batch = blockIdx.y;
        const index_t i_block = blockIdx.z;

        // Split the flattened tile index: quotient is the M tile, remainder the N tile.
        const index_t i_tile_m = i_block / num_tile_n1;
        const index_t i_tile_n = i_block - i_tile_m * num_tile_n1;

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};
}
// namespace ck_tile
test/position_embedding/position_embedding.cpp
View file @
0144b4f4
...
...
@@ -131,74 +131,74 @@ int main()
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
4
,
5
,
1
,
0
,
1
,
2
,
3
,
4
,
2
,
1
,
0
,
1
,
2
,
3
,
3
,
2
,
1
,
0
,
1
,
2
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
1
,
0
,
1
,
2
,
2
,
1
,
0
,
1
,
3
,
2
,
1
,
0
,
4
,
3
,
2
,
1
,
5
,
4
,
3
,
2
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
1
,
0
,
1
,
2
,
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
1
,
0
,
1
,
2
,
3
,
3
,
2
,
1
,
0
,
1
,
2
,
4
,
3
,
2
,
1
,
0
,
1
,
5
,
4
,
3
,
2
,
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
3
,
4
,
5
,
1
,
2
,
3
,
4
,
0
,
1
,
2
,
3
,
1
,
0
,
1
,
2
,
2
,
1
,
0
,
1
,
3
,
2
,
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
1
,
2
,
1
,
0
,
1
,
2
,
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
-
4
,
-
5
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
4
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
-
1
,
0
,
-
1
,
-
2
,
-
2
,
-
1
,
0
,
-
1
,
-
3
,
-
2
,
-
1
,
0
,
-
4
,
-
3
,
-
2
,
-
1
,
-
5
,
-
4
,
-
3
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
4
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
5
,
-
4
,
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
3
,
-
4
,
-
5
,
-
1
,
-
2
,
-
3
,
-
4
,
0
,
-
1
,
-
2
,
-
3
,
-
1
,
0
,
-
1
,
-
2
,
-
2
,
-
1
,
0
,
-
1
,
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
true
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
-
1
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
VERTICAL
,
{
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
4
,
5
,
1
,
0
,
1
,
2
,
3
,
4
,
2
,
1
,
0
,
1
,
2
,
3
,
3
,
2
,
1
,
0
,
1
,
2
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
3
,
1
,
0
,
1
,
2
,
2
,
1
,
0
,
1
,
3
,
2
,
1
,
0
,
4
,
3
,
2
,
1
,
5
,
4
,
3
,
2
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
1
,
2
,
1
,
0
,
1
,
2
,
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
1
,
0
,
1
,
2
,
3
,
3
,
2
,
1
,
0
,
1
,
2
,
4
,
3
,
2
,
1
,
0
,
1
,
5
,
4
,
3
,
2
,
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
2
,
3
,
4
,
5
,
1
,
2
,
3
,
4
,
0
,
1
,
2
,
3
,
1
,
0
,
1
,
2
,
2
,
1
,
0
,
1
,
3
,
2
,
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
1
,
2
,
1
,
0
,
1
,
2
,
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
-
4
,
-
5
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
4
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
3
,
-
1
,
0
,
-
1
,
-
2
,
-
2
,
-
1
,
0
,
-
1
,
-
3
,
-
2
,
-
1
,
0
,
-
4
,
-
3
,
-
2
,
-
1
,
-
5
,
-
4
,
-
3
,
-
2
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_TOP_LEFT
,
{
0
,
-
1
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
4
,
6
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
3
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
4
,
-
3
,
-
2
,
-
1
,
0
,
-
1
,
-
5
,
-
4
,
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
6
,
4
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
-
2
,
-
3
,
-
4
,
-
5
,
-
1
,
-
2
,
-
3
,
-
4
,
0
,
-
1
,
-
2
,
-
3
,
-
1
,
0
,
-
1
,
-
2
,
-
2
,
-
1
,
0
,
-
1
,
-
3
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_traverse_with_slope
<
false
,
dtype
>
(
3
,
3
,
slope
,
ck_tile
::
AlibiMode
::
FROM_BOTTOM_RIGHT
,
{
0
,
-
1
,
-
2
,
-
1
,
0
,
-
1
,
-
2
,
-
1
,
0
});
rtn
&=
test_alibi_slope_generation
<
float
>
(
8
,
{
0.5
,
0.25
,
0.125
,
0.0625
,
0.03125
,
0.015625
,
0.0078125
,
0.00390625
});
rtn
&=
test_alibi_slope_generation
<
float
>
(
16
,
{
0.7071067811865476
,
0.5
,
0.35355339059327384
,
0.25000000000000006
,
0.17677669529663692
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment