gaoqiong/flash-attention

Commit 3e2c827d
Authored Jan 20, 2024 by Tri Dao
Parent: 66a127ae

Remove unused kernel_traits file
Showing 1 changed file with 0 additions and 159 deletions.

csrc/flash_attn/src/kernel_traits_sm90.h  (+0, −159)  deleted 100644 → 0
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cute/algorithm/copy.hpp"

#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>

using namespace cute;

template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits_sm90 {

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
    static constexpr bool Has_cp_async = true;
#else
    using Element = cutlass::half_t;
    static constexpr bool Has_cp_async = false;
#endif

    using ElementAccum = float;
    using index_t = uint32_t;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MMA_Atom_Arch = std::conditional_t<
        std::is_same_v<elem_type, cutlass::half_t>,
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
    >;
    using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
    using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
    using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
    using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
    using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false,
         typename elem_type=cutlass::half_t,
         typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
    using Element = typename Base::Element;
    using ElementAccum = typename Base::ElementAccum;
    using index_t = typename Base::index_t;
    static constexpr bool Has_cp_async = Base::Has_cp_async;
    using SmemCopyAtom = typename Base::SmemCopyAtom;
    using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;

    static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
    static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
    static constexpr int kNThreads = kNWarps * 32;

    static constexpr int kBlockM = kBlockM_;
    static constexpr int kBlockN = kBlockN_;
    static constexpr int kHeadDim = kHeadDim_;
    static_assert(kHeadDim % 32 == 0);
    static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
    static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
    static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;

    using TiledMma = TiledMMA<
        typename Base::MMA_Atom_Arch,
        Layout<Shape<Int<kNWarps>, _1, _1>>,  // 4x1x1 or 8x1x1 thread group
        typename Base::ValLayoutMNK>;  // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM

    using SmemLayoutAtomQ = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
                    Layout<Shape<_8, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutQ = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));

    using SmemLayoutKV = decltype(tile_to_shape(
        SmemLayoutAtomQ{},
        Shape<Int<kBlockN>, Int<kHeadDim>>{}));

    using SmemLayoutAtomVtransposed = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
                           Stride<_1, Int<kBlockKSmem>>>{}));
    using SmemLayoutVtransposed = decltype(tile_to_shape(
        SmemLayoutAtomVtransposed{},
        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
    // Maybe the VtransposeNoSwizzle just needs to have the right shape
    // And the strides don't matter?
    using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());

    using SmemLayoutAtomO = decltype(
        composition(Swizzle<kSwizzle, 3, 3>{},
                    Layout<Shape<Int<8>, Int<kBlockKSmem>>,
                           Stride<Int<kBlockKSmem>, _1>>{}));
    using SmemLayoutO = decltype(tile_to_shape(
        SmemLayoutAtomO{},
        Shape<Int<kBlockM>, Int<kHeadDim>>{}));
    using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;

    static constexpr int kSmemQCount = size(SmemLayoutQ{});
    static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
    static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
    static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
    static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;

    static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
    static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
    // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
    // For example, for d=128, smem is split into 2 "pages", each page takes care of columns
    // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
    // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
    // to the same banks.
    static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
    using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
                                  Stride<Int<kGmemThreadsPerRow>, _1>>;

    // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
    // from the same address by the same threadblock. This is slightly faster.
    using Gmem_copy_struct = std::conditional_t<
        Has_cp_async,
        SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
        DefaultCopy
    >;
    using GmemTiledCopyQKV = decltype(
        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
    using GmemTiledCopyO = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtom{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

    static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
    static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
    using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
                                   Stride<Int<kGmemThreadsPerRowP>, _1>>;

    using GmemTiledCopyP = decltype(
        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
                        GmemLayoutAtomP{},
                        Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

};
////////////////////////////////////////////////////////////////////////////////////////////////////
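The deleted traits are pure compile-time arithmetic, so their concrete values are easy to check in isolation. The sketch below is an editor's illustration, not part of the repository: it replicates only the integer math of Flash_fwd_kernel_traits for a typical d=128 configuration (kBlockM=128, kBlockN=64, kNWarps=4, 2-byte half_t elements), with no cute/cutlass dependency, and re-checks the same divisibility constraints the header asserted. All names below are local to the sketch.

// Standalone sketch: mirrors the constexpr arithmetic of the deleted
// Flash_fwd_kernel_traits for kHeadDim=128, kBlockM=128, kBlockN=64,
// kNWarps=4, elem_type=half_t (2 bytes). No CUTLASS/cute headers needed.
#include <cstdio>

constexpr int kHeadDim = 128, kBlockM = 128, kBlockN = 64, kNWarps = 4;
constexpr int kElemBytes = 2;  // sizeof(cutlass::half_t)

constexpr int kNThreads   = kNWarps * 32;                          // 128
constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;          // 64
constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128
                          : (kHeadDim % 64 == 0 ? 64 : 32);        // 128
constexpr int kSwizzle    = kBlockKSmem == 32 ? 2 : 3;             // 3

// Shared-memory footprint: one Q tile plus K and V tiles (the factor 2).
constexpr int kSmemQSize  = kBlockM * kHeadDim * kElemBytes;       // 32768 B
constexpr int kSmemKVSize = 2 * kBlockN * kHeadDim * kElemBytes;   // 32768 B
constexpr int kSmemSize   = kSmemQSize + kSmemKVSize;              // 65536 B (Share_Q_K_smem = false)

// Global-memory vectorization: 16-byte (uint128_t) loads of 2-byte elements.
constexpr int kGmemElemsPerLoad  = 16 / kElemBytes;                // 8
constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;// 8

// The same constraints the deleted header enforced.
static_assert(kHeadDim % 32 == 0);
static_assert(kHeadDim % kGmemElemsPerLoad == 0);
static_assert(kNThreads % kGmemThreadsPerRow == 0);

int main() {
    std::printf("kNThreads=%d kBlockKSmem=%d kBlockKGmem=%d kSwizzle=%d kSmemSize=%d bytes\n",
                kNThreads, kBlockKSmem, kBlockKGmem, kSwizzle, kSmemSize);
}

So for this configuration the block uses 64 KiB of shared memory and a 16x8 thread arrangement (GmemLayoutAtom) for global loads, with each thread moving 8 half-precision values per copy.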
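The bank-conflict comment attached to kGmemThreadsPerRow can also be made concrete, assuming the usual NVIDIA shared-memory model of 32 four-byte banks, where each 16-byte vectorized store touches 4 consecutive banks. The sketch below is an editor's illustration under that assumption, not code from the repository: with 8 threads per row (kBlockKSmem = 64) a row's stores cover all 32 banks exactly once, while with 16 threads per row (kBlockKGmem = 128) threads 8-15 land on the same banks as threads 0-7, a two-way conflict.

// Bank mapping for one row of a gmem->smem tile, 16-byte stores of half_t.
// Model: bank(byte_offset) = (byte_offset / 4) % 32 (NVIDIA shared memory).
#include <cstdio>

constexpr int kBankBytes = 4, kNBanks = 32, kVecBytes = 16;

void show(int threadsPerRow) {
    std::printf("%2d threads/row:", threadsPerRow);
    for (int t = 0; t < threadsPerRow; ++t) {
        // First of the 4 consecutive banks this thread's 16-byte store hits.
        int firstBank = (t * kVecBytes / kBankBytes) % kNBanks;
        std::printf(" t%d->banks[%d..%d]", t, firstBank, firstBank + 3);
    }
    std::printf("\n");
}

int main() {
    show(8);   // kBlockKSmem=64: threads 0..7 cover banks 0..31 once (conflict-free)
    show(16);  // kBlockKGmem=128: threads 8..15 wrap back onto banks 0..31 (2-way conflict)
}

The wrap-around in the 16-thread case is exactly the "second page written to the same banks" situation the original comment describes, which is why the layout uses kBlockKSmem rather than kBlockKGmem for the thread arrangement.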