gaoqiong / composable_kernel_ROCM / Commits / cfc2be07

Commit cfc2be07, authored Jul 03, 2024 by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

Parents: 30e4f4eb, 497ccb87

Changes: 257. Showing 20 changed files with 445 additions and 45 deletions (+445 -45) on this page.
include/ck/utility/amd_buffer_addressing.hpp              +2   -1
include/ck/utility/amd_smfmac.hpp                         +69  -0
include/ck/utility/amd_wave_read_first_lane.hpp           +23  -1
include/ck/utility/amd_wmma.hpp                           +82  -0
include/ck/utility/array.hpp                              +2   -0
include/ck/utility/data_type.hpp                          +1   -1
include/ck/utility/synchronization.hpp                    +17  -0
include/ck_tile/core.hpp                                  +1   -0
include/ck_tile/core/arch/amd_buffer_addressing.hpp       +11  -4
include/ck_tile/core/arch/arch.hpp                        +7   -4
include/ck_tile/core/config.hpp                           +8   -1
include/ck_tile/core/numeric/null_type.hpp                +13  -0
include/ck_tile/core/tensor/tile_elementwise.hpp          +41  -2
include/ck_tile/host.hpp                                  +1   -0
include/ck_tile/host/check_err.hpp                        +15  -10
include/ck_tile/host/reference/reference_layernorm2d.hpp  +69  -0
include/ck_tile/host/timer.hpp                            +5   -5
include/ck_tile/ops/fmha.hpp                              +10  -0
include/ck_tile/ops/fmha/block/block_dropout.hpp          +50  -15
include/ck_tile/ops/fmha/block/block_masking.hpp          +18  -1
include/ck/utility/amd_buffer_addressing.hpp

@@ -991,7 +991,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
     asm volatile("s_mov_b32 m0, %0; \n\t"
                  "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
                  "v"(global_offset_bytes),
-                 "s"(src_resource));
+                 "s"(src_resource)
+                 : "memory");
 #else
     // The LDS pointer must be attributed with the LDS address space.
     __attribute__((address_space(3))) uint32_t* lds_ptr =
...
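As an aside, a minimal generic sketch (not CK code) of what the added "memory" clobber changes: it marks the asm as touching memory, so the compiler may no longer reorder or elide surrounding memory accesses across it.

// Illustrative only: "memory" turns the empty asm into a compiler barrier.
__device__ void clobber_example(int* p)
{
    *p = 1;
    asm volatile("" ::: "memory"); // the store above must be emitted before this point
    *p = 2;                        // and this one after it
}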
include/ck/utility/amd_smfmac.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"

namespace ck {

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32f16;

template <>
struct intrin_smfmac_f32_16x16x32f16<16, 16>
{
    template <class FloatC>
    __device__ static void
    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32bf16;

template <>
struct intrin_smfmac_f32_16x16x32bf16<16, 16>
{
    template <class FloatC>
    __device__ static void
    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16f16;

template <>
struct intrin_smfmac_f32_32x32x16f16<32, 32>
{
    template <class FloatC>
    __device__ static void
    Run(const half4_t& reg_a, const half8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16bf16;

template <>
struct intrin_smfmac_f32_32x32x16bf16<32, 32>
{
    template <class FloatC>
    __device__ static void
    Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const int32_t& reg_idx, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
                reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

} // namespace ck
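A hypothetical invocation sketch for the wrappers above (the accumulator type and surrounding function are our assumptions, not part of this commit): reg_a carries the four kept A-operands of a 2:4-sparse fragment, reg_idx the sparsity metadata, and reg_c is updated in place.

// Sketch: one 16x16x32 sparse-MFMA step inside a hypothetical wave-level loop.
__device__ void smfmac_step(const ck::half4_t& a_compressed, // kept A values (2:4 sparsity)
                            const ck::half8_t& b_dense,      // dense B fragment
                            int32_t sparsity_index,          // metadata selecting A positions
                            ck::vector_type<float, 4>& acc)  // f32 accumulator, read-modify-write
{
    // cbsz/abid broadcast controls are hard-wired to 0 by the wrapper.
    ck::intrin_smfmac_f32_16x16x32f16<16, 16>::Run(a_compressed, b_dense, sparsity_index, acc);
}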
include/ck/utility/amd_wave_read_first_lane.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -95,11 +95,33 @@ using get_carrier_t = typename get_carrier<SizeInBytes>::type;
} // namespace detail

__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value)
{
    return __builtin_amdgcn_readfirstlane(value);
}

__device__ inline int32_t amd_wave_read_first_lane(int32_t value)
{
    return __builtin_amdgcn_readfirstlane(value);
}

__device__ inline int64_t amd_wave_read_first_lane(int64_t value)
{
    constexpr unsigned object_size        = sizeof(int64_t);
    constexpr unsigned second_part_offset = object_size / 2;

    auto* const from_obj = reinterpret_cast<const std::byte*>(&value);
    alignas(int64_t) std::byte to_obj[object_size];

    using Sgpr = uint32_t;

    // readfirstlane is an SGPR-sized (32-bit) operation, so broadcast the
    // 64-bit value in two halves.
    *reinterpret_cast<Sgpr*>(to_obj) =
        amd_wave_read_first_lane(*reinterpret_cast<const Sgpr*>(from_obj));
    *reinterpret_cast<Sgpr*>(to_obj + second_part_offset) = amd_wave_read_first_lane(
        *reinterpret_cast<const Sgpr*>(from_obj + second_part_offset));

    return *reinterpret_cast<int64_t*>(to_obj);
}

template <typename Object,
          typename = std::enable_if_t<std::is_class_v<Object> &&
                                      std::is_trivially_copyable_v<Object>>>
...
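A hypothetical usage sketch (the kernel is ours; it assumes these overloads live in namespace ck, as the surrounding code suggests): broadcasting a 64-bit offset so the compiler can keep it in SGPRs.

// Sketch: make a per-wave-uniform 64-bit offset explicit. The int64_t overload
// above broadcasts the value half by half, since readfirstlane is 32 bits wide.
__device__ float load_wave_uniform(const float* base, int64_t offset_of_lane0)
{
    const int64_t uniform_offset = ck::amd_wave_read_first_lane(offset_of_lane0);
    return base[uniform_offset]; // all lanes now index with lane 0's offset
}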
include/ck/utility/amd_wmma.hpp

@@ -257,5 +257,87 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
    }
};

// gfx12
/********************************WAVE32 MODE***********************************************/
#if defined(__gfx1200__) || defined(__gfx1201__)
#define __gfx12__
#endif

// src: fp16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12;

template <>
struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
    {
        // * Inline assembly needs to eliminate the duplicated data loads itself;
        //   the compiler won't delete them for you.
        // amd_assembly_wmma_f32_16x16x16_f16_w32(
        //     reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx12__)
        reg_c.template AsType<float8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(
                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: bf16, dst: fp32
template <index_t MPerWave, index_t NPerWave>
struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12;

template <>
struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx12__)
        reg_c.template AsType<float8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(
                reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

// src: iu8, dst: i32
template <index_t MPerWave, index_t NPerWave, bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12;

template <bool neg_a, bool neg_b, bool clamp>
struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp>
{
    template <class FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx12__)
        reg_c.template AsType<int32x8_t>()(Number<0>{}) =
            __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
                neg_a,
                bit_cast<int32x2_t>(reg_a),
                neg_b,
                bit_cast<int32x2_t>(reg_b),
                reg_c.template AsType<int32x8_t>()[Number<0>{}],
                clamp);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};

} // namespace ck
#endif
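For orientation, a minimal sketch of how one of the new gfx12 wave32 wrappers might be driven (the fragment types follow the Run signature; the function name and accumulator are illustrative assumptions):

// Sketch: one 16x16x16 WMMA accumulation step on gfx12, wave32 mode.
__device__ void wmma_step(const ck::half8_t& a_frag,
                          const ck::half8_t& b_frag,
                          ck::vector_type<float, 8>& acc) // read-modify-write accumulator
{
    // On non-gfx12 targets the wrapper above compiles to a no-op (operands ignored).
    ck::intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16>::Run(a_frag, b_frag, acc);
}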
include/ck/utility/array.hpp

@@ -36,6 +36,8 @@ struct Array
        return *this;
    }

    __host__ __device__ constexpr const TData* begin() const { return &mData[0]; }
    __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; }
};

// empty Array
...
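The new const begin()/end() overloads make a const Array iterable; a small sketch of what they enable (the summing function is ours):

// Sketch: range-based for over a const ck::Array, using the added overloads.
__device__ float sum_array(const ck::Array<float, 4>& a)
{
    float s = 0.f;
    for(const float v : a) // calls the const begin()/end() added in this commit
        s += v;
    return s;
}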
include/ck/utility/data_type.hpp

@@ -203,7 +203,7 @@ struct vector_type<T, 1>
     }
 };

-int static err = 0;
+__device__ int static err = 0;

 template <typename T>
 struct vector_type<T, 2>
 {
...
include/ck/utility/synchronization.hpp

@@ -10,12 +10,20 @@ namespace ck {
__device__ void block_sync_lds()
{
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
#ifdef __gfx12__
    // gfx12 uses split barriers: wait on the LDS counter, then signal and wait.
    asm volatile("\
    s_wait_dscnt 0x0 \n \
    s_barrier_signal -1 \n \
    s_barrier_wait -1 \
    " ::);
#else
    // asm volatile("\
    // s_waitcnt lgkmcnt(0) \n \
    // s_barrier \
    // " ::);
    __builtin_amdgcn_s_waitcnt(0xc07f);
    __builtin_amdgcn_s_barrier();
#endif
#else
    __syncthreads();
#endif
...
@@ -23,11 +31,20 @@ __device__ void block_sync_lds()
__device__ void block_sync_lds_direct_load()
{
#ifdef __gfx12__
    asm volatile("\
    s_wait_vmcnt 0x0 \n \
    s_wait_dscnt 0x0 \n \
    s_barrier_signal -1 \n \
    s_barrier_wait -1 \
    " ::);
#else
    asm volatile("\
    s_waitcnt vmcnt(0) \n \
    s_waitcnt lgkmcnt(0) \n \
    s_barrier \
    " ::);
#endif
}

__device__ void s_nop()
...
include/ck_tile/core.hpp

@@ -27,6 +27,7 @@
 #include "ck_tile/core/numeric/integer.hpp"
 #include "ck_tile/core/numeric/integral_constant.hpp"
 #include "ck_tile/core/numeric/math.hpp"
+#include "ck_tile/core/numeric/null_type.hpp"
 #include "ck_tile/core/numeric/numeric.hpp"
 #include "ck_tile/core/numeric/type_convert.hpp"
 #include "ck_tile/core/numeric/vector_type.hpp"
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp

@@ -26,7 +26,12 @@ struct __attribute__((packed)) buffer_resource
 CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff)
 {
     buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
-    return __builtin_bit_cast(int32x4_t, res);
+    int32x4_t r = __builtin_bit_cast(int32x4_t, res);
+    // Force every dword of the descriptor into SGPRs: buffer resources must be wave-uniform.
+    r.x = __builtin_amdgcn_readfirstlane(r.x);
+    r.y = __builtin_amdgcn_readfirstlane(r.y);
+    r.z = __builtin_amdgcn_readfirstlane(r.z);
+    r.w = __builtin_amdgcn_readfirstlane(r.w);
+    return r;
 }

 namespace impl {
...
@@ -552,8 +557,9 @@ namespace impl{
 template <index_t N>
 CK_TILE_DEVICE void insert_dummy_dep_per_dword(array<float, N>& b)
 {
-    static_for<0, b.size(), 1>{}(
-        [&](auto i) { asm volatile(" " : : "v"(b.get(i)) : "memory"); });
+    constexpr auto kSize = remove_cvref_t<decltype(b)>::size();
+    static_for<0, kSize, 1>{}(
+        [&](auto i) { asm volatile(" " : : "v"(b.get(number<i>{})) : "memory"); });
 }
 #if 1
...
@@ -2103,7 +2109,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
     asm volatile("s_mov_b32 m0, %0; \n\t"
                  "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
                  "v"(global_offset_bytes),
-                 "s"(src_resource));
+                 "s"(src_resource)
+                 : "memory");
 #else
     // The LDS pointer must be attributed with the LDS address space.
     __attribute__((address_space(3))) uint32_t* lds_ptr =
...
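A short sketch of the intent behind the readfirstlane calls added to make_wave_buffer_resource (the kernel below is an assumption, not CK code): buffer descriptors are consumed from SGPRs, so making uniformity explicit can spare the compiler from emitting waterfall loops when it cannot prove the descriptor is uniform on its own.

// Sketch: build a wave-uniform 128-bit buffer descriptor once per tile.
__global__ void use_descriptor(const float* p_a, float* p_out, uint32_t bytes)
{
    using namespace ck_tile;
    // Each dword of the returned int32x4_t has been passed through
    // __builtin_amdgcn_readfirstlane, so it is guaranteed scalar.
    const int32x4_t a_rsrc = make_wave_buffer_resource(p_a, bytes);
    (void)a_rsrc; // subsequent buffer load/store intrinsics (not shown) take a_rsrc
    (void)p_out;
}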
include/ck_tile/core/arch/arch.hpp

@@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
 CK_TILE_DEVICE void block_sync_lds()
 {
 #if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
-    asm volatile("\
-    s_waitcnt lgkmcnt(0) \n \
-    s_barrier \
-    " ::);
+    // asm volatile("\
+    // s_waitcnt lgkmcnt(0) \n \
+    // s_barrier \
+    // " ::);
+    __builtin_amdgcn_s_waitcnt(0xc07f);
+    __builtin_amdgcn_s_barrier();
 #else
     __syncthreads();
 #endif
...
include/ck_tile/core/config.hpp

@@ -17,6 +17,9 @@
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
 #define __gfx11__
 #endif
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define __gfx12__
+#endif

 #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
...
@@ -155,7 +158,7 @@
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx103__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31014000
-#elif defined(__gfx11__) // for GPU code
+#elif defined(__gfx11__) || defined(__gfx12__) // for GPU code
 #define CK_TILE_BUFFER_RESOURCE_3RD_DWORD 0x31004000
 #endif
...
@@ -167,6 +170,10 @@
 #define CK_TILE_USE_SUBDWORD_TILE_CAST 0
 #endif

+#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
+#define CK_TILE_USE_PK_FP16_TILE_CAST 0
+#endif
+
 // TODO: better solve this inside compiler
 #ifndef CK_TILE_FMHA_FWD_FAST_EXP2
 #define CK_TILE_FMHA_FWD_FAST_EXP2 0
...
include/ck_tile/core/numeric/null_type.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <stdint.h>

namespace ck_tile {

// Tag type used where a tensor/result is intentionally absent
// (e.g. skipping the mean/inv-std outputs of layernorm).
struct null_type
{
};

} // namespace ck_tile
include/ck_tile/core/tensor/tile_elementwise.hpp

@@ -110,7 +110,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
 namespace impl {

 // TODO: this is ugly
 template <typename OutDataType, typename InTensor>
-CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
+CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
 {
 #if defined(__gfx94__)
     // This API is designed to use the _pk_ series of functions
...
@@ -156,6 +156,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
#endif
}

template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
    // This API is designed to use the _pk_ series of functions
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 2 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    // TODO: this is an RTZ (round-toward-zero) conversion; be very careful
    for(index_t i = 0; i < thread_buffer_size_pk; i++)
    {
        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);

        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
    }

    return out_dstr_tensor;
#else
    // fallback
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}

#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes either the src or dst (or both) data type is under 1 dword;
// we pack subdword values into 1 dword to avoid the compiler's default subdword behavior (which is buggy)
...
@@ -229,8 +260,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
                          float> &&
                         (SrcTensor::get_thread_buffer_size() % 4 == 0))
     {
-        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
+        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
     }
+#if CK_TILE_USE_PK_FP16_TILE_CAST
+    else if constexpr(std::is_same_v<DstType, fp16_t> &&
+                      std::is_same_v<typename SrcTensor::DataType, float> &&
+                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
+    {
+        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
+    }
+#endif
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
     else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
     {
...
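A standalone sketch of the packed conversion the new fp16 path is built on (the typedef is ours): __builtin_amdgcn_cvt_pkrtz converts two floats to two halves with a single v_cvt_pkrtz_f16_f32, rounding toward zero rather than to nearest-even, which is why the TODO above urges caution.

// Sketch: pairwise fp32 -> fp16 conversion with round-toward-zero semantics.
typedef _Float16 fp16x2_t __attribute__((ext_vector_type(2))); // local typedef, ours

__device__ void pack_pair(const float* in, _Float16* out)
{
    const fp16x2_t o = __builtin_amdgcn_cvt_pkrtz(in[0], in[1]);
    out[0] = o.x; // low half
    out[1] = o.y; // high half
}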
include/ck_tile/host.hpp

@@ -18,6 +18,7 @@
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
+#include "ck_tile/host/reference/reference_layernorm2d.hpp"
 #include "ck_tile/host/reference/reference_reduce.hpp"
 #include "ck_tile/host/reference/reference_softmax.hpp"
 #include "ck_tile/host/stream_config.hpp"
...
include/ck_tile/host/check_err.hpp

@@ -56,8 +56,9 @@ check_err(const Range& out,
     }

     const auto is_infinity_error = [=](auto o, auto r) {
         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) &&
+                                            (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));

         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
     };
...

The identical change is applied to the is_infinity_error lambda in the four other check_err overloads, at @@ -114,8 +115,9 @@, @@ -173,8 +175,9 @@, @@ -285,8 +288,9 @@, and @@ -357,8 +361,9 @@.
...
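One plausible reading of this change (our interpretation, not stated in the commit): comparing bit patterns instead of using floating-point == keeps the matching-infinity check meaningful even under fast-math flags, where the compiler may assume all values are finite and fold the comparison away. A host-side sketch of the idea:

// Sketch: bit-level identity test, immune to fast-math assumptions about '=='.
#include <cstdint>
#include <cstring>

static bool same_bits(double o, double r)
{
    std::uint64_t ob, rb;
    std::memcpy(&ob, &o, sizeof ob); // memcpy as a portable bit_cast
    std::memcpy(&rb, &r, sizeof rb);
    return ob == rb; // +inf matches +inf, -inf matches -inf; sign is preserved
}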
include/ck_tile/host/reference/reference_layernorm2d.hpp (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename ComputeDataType,
          typename YDataType,
          typename MeanDataType,
          typename InvStdDataType>
void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                               const HostTensor<GammaDataType>& gamma_n,
                               const HostTensor<BetaDataType>& beta_n,
                               HostTensor<YDataType>& y_m_n,
                               HostTensor<MeanDataType>& mean_m,
                               HostTensor<InvStdDataType>& invStd_m,
                               ComputeDataType epsilon)
{
    auto layernorm2d_fwd_func = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        int count                = 0;
        ComputeDataType mean     = 0;
        ComputeDataType variance = 0;
        ComputeDataType divisor  = 0;

        // single-pass mean/variance (Welford's online algorithm)
        for(int n = 0; n < N; ++n)
        {
            ++count;
            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            ComputeDataType delta = x - mean;
            mean += delta / count;
            ComputeDataType delta2 = x - mean;
            variance += delta * delta2;
        }

        // actual variance
        variance = variance / count;
        divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(variance + epsilon);

        if constexpr(!std::is_same_v<MeanDataType, ck_tile::null_type>)
            mean_m(m) = ck_tile::type_convert<MeanDataType>(mean);

        if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
            invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);

        for(int n = 0; n < N; ++n)
        {
            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
            ComputeDataType beta  = ck_tile::type_convert<ComputeDataType>(beta_n(n));

            auto y = (x - mean) * divisor;
            y      = y * gamma + beta;

            y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
        }
    };

    make_ParallelTensorFunctor(layernorm2d_fwd_func, mean_m.mDesc.get_lengths()[0])(
        std::thread::hardware_concurrency());
}

} // namespace ck_tile
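The single-pass loop above is Welford's online algorithm; for reference, the recurrences it implements per row m (standard notation, ours):

$$\mu_k = \mu_{k-1} + \frac{x_k - \mu_{k-1}}{k}, \qquad M_k = M_{k-1} + (x_k - \mu_{k-1})(x_k - \mu_k)$$

$$\mathrm{Var} = \frac{M_N}{N}, \qquad \mathrm{invStd} = \frac{1}{\sqrt{\mathrm{Var} + \varepsilon}}, \qquad y_{m,n} = (x_{m,n} - \mu_m)\,\mathrm{invStd}_m\,\gamma_n + \beta_n$$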
include/ck_tile/host/timer.hpp

@@ -27,7 +27,7 @@ struct gpu_timer
     CK_TILE_HOST void start(const hipStream_t& s)
     {
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
+        HIP_CHECK_ERROR(hipStreamSynchronize(s));
         HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
     }
...
@@ -51,15 +51,15 @@ struct gpu_timer
 struct cpu_timer
 {
     // like torch.utils.benchmark.Timer(), there is a sync inside each timer callback
-    CK_TILE_HOST void start(const hipStream_t&)
+    CK_TILE_HOST void start(const hipStream_t& s)
     {
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
+        HIP_CHECK_ERROR(hipStreamSynchronize(s));
         start_tick = std::chrono::high_resolution_clock::now();
     }

     // like torch.utils.benchmark.Timer(), there is a sync inside each timer callback
-    CK_TILE_HOST void stop(const hipStream_t&)
+    CK_TILE_HOST void stop(const hipStream_t& s)
     {
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
+        HIP_CHECK_ERROR(hipStreamSynchronize(s));
         stop_tick = std::chrono::high_resolution_clock::now();
     }

     // return in ms
...
include/ck_tile/ops/fmha.hpp

@@ -10,6 +10,10 @@
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
...
@@ -22,6 +26,12 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
...
include/ck_tile/ops/fmha/block/block_dropout.hpp

@@ -8,6 +8,20 @@
namespace ck_tile {

struct NullBlockDropout
{
    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
    __host__ __device__ static constexpr auto
    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                          index_t seqlen_qk_start)
    {
        (void)randval_dram_block_window_tmp;
        (void)seqlen_qk_start;
        return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
    }
};

struct BlockDropout
{
    CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch,
...
@@ -195,6 +209,42 @@ struct BlockDropout
        MakeRandValLdsShuffleTileDistribution<BlockGemm>());

    const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
    if(is_store_randval)
    {
        static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
            static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
                int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
                int block_col_start = (start_n0_idx / WG::kN) + i_n0;
                uint2 rowcol        = make_uint2(block_row_start, block_col_start);

                // generate random numbers
                uint8_t random_uint8_t[16];
                ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

                constexpr auto randval_dist_generated_spans =
                    decltype(randval_dist_generated)::get_distributed_spans();
                int i_random_idx = 0;
                sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto i_j_idx          = ck_tile::make_tuple(idx0, idx1);
                        randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
                    });
                });

                // save to LDS
                store_tile(randval_lds_window, randval_dist_generated);
                block_sync_lds();

                // read from LDS to registers
                auto randval = load_tile(randval_lds_read_window);

                // save to global memory
                const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                store_tile(randval_dram_window, randval_store);
                move_tile_window(randval_dram_window, {0, kNPerStep});
            });
            move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
        });
        move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
    }

    static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
        static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
            int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
...
@@ -232,23 +282,8 @@ struct BlockDropout
                                   : PComputeDataType(0);
             });
         });
-        // save to Global
-        if(is_store_randval)
-        {
-            const auto randval_store = cast_tile<RandValOutputDataType>(randval);
-            store_tile(randval_dram_window, randval_store);
-            move_tile_window(randval_dram_window, {0, kNPerStep});
-        }
         });
-        if(is_store_randval)
-        {
-            move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
-        }
     });
-    if(is_store_randval)
-    {
-        move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
-    }
 }

 template <typename BlockGemm,
...
include/ck_tile/ops/fmha/block/block_masking.hpp

@@ -299,6 +299,23 @@ struct SimplifiedGenericAttentionMask
        }
    }

    template <index_t TileHeight, index_t TileWidth>
    CK_TILE_HOST_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y,
                                                          number<TileHeight> height,
                                                          number<TileWidth> width,
                                                          index_t num_splits,
                                                          index_t i_split) const
    {
        auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);

        const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
        const index_t split_start = x_per_split * i_split;
        const index_t split_end =
            (i_split == num_splits - 1 ? x_total : split_start + x_per_split);

        return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
                                   ck_tile::min(origin_end, split_end));
    }

    // to get the loop length along the Y axis, return index [start, end); end - start = length
    // use this if you need to loop over the Y axis tile by tile (like a q-seqlen loop)
    // TODO: y_end could still be negative, so end - start could be negative (needs a check)
...
@@ -372,7 +389,7 @@ struct SimplifiedGenericAttentionMask
         // index_t x_end = min(i_y + x, x_total);

         bool top_right_edge   = i_x_end > min(i_y + x, x_total); // consider right pad
-        bool bottom_left_edge = i_y_end > (i_x + y);
+        bool bottom_left_edge = i_y_end > min(i_x + y, y_total); // consider bottom pad
         // bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now

         return top_right_edge || bottom_left_edge;
...
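A small worked example of the new split arithmetic (the numbers are ours): with x_total = 10 tiles and num_splits = 4, every split gets x_per_split = max(1, 10/4) = 2 tiles except the last, which absorbs the remainder; the returned range is then clipped against the mask's own [origin_start, origin_end).

// Host-side re-derivation of the split ranges (illustrative, not CK code).
#include <algorithm>
#include <cstdio>

int main()
{
    const int x_total = 10, num_splits = 4;
    for(int i_split = 0; i_split < num_splits; ++i_split)
    {
        const int x_per_split = std::max(1, x_total / num_splits); // 2
        const int split_start = x_per_split * i_split;             // 0, 2, 4, 6
        const int split_end =
            (i_split == num_splits - 1) ? x_total : split_start + x_per_split;
        std::printf("split %d -> [%d, %d)\n", i_split, split_start, split_end);
    }
    // prints [0,2) [2,4) [4,6) [6,10); the last split takes the remainder.
    return 0;
}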