OpenDAS / apex

Commit 5c9b21d8 (unverified)
Authored Apr 16, 2021 by yjk21, committed by GitHub on Apr 16, 2021

adds fmhalib (#1074)

Parent: e5f2f675

Showing 20 changed files with 5723 additions and 0 deletions (+5723, -0)
Files changed:

apex/contrib/csrc/fmha/fmha_api.cpp                                  +305  -0
apex/contrib/csrc/fmha/src/fmha.h                                     +92  -0
apex/contrib/csrc/fmha/src/fmha/gemm.h                               +317  -0
apex/contrib/csrc/fmha/src/fmha/gmem_tile.h                          +426  -0
apex/contrib/csrc/fmha/src/fmha/kernel_traits.h                       +95  -0
apex/contrib/csrc/fmha/src/fmha/mask.h                                +76  -0
apex/contrib/csrc/fmha/src/fmha/smem_tile.h                         +1253  -0
apex/contrib/csrc/fmha/src/fmha/softmax.h                            +478  -0
apex/contrib/csrc/fmha/src/fmha/utils.h                              +953  -0
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu      +60  -0
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu      +60  -0
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu      +60  -0
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu      +60  -0
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h            +599  -0
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu      +58  -0
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu      +58  -0
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu      +57  -0
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu      +56  -0
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h                   +338  -0
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h          +322  -0
apex/contrib/csrc/fmha/fmha_api.cpp  (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "fmha.h"
void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);
void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream);

void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);
void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream);

void set_params(Fused_multihead_attention_fprop_params &params,
                // sizes
                const size_t b,
                const size_t s,
                const size_t h,
                const size_t d,
                // device pointers
                void *qkv_packed_d,
                void *cu_seqlens_d,
                void *seqlens_d,
                void *o_packed_d,
                void *s_d,
                float p_dropout) {

    Data_type acc_type = DATA_TYPE_FP32;
    Data_type data_type = DATA_TYPE_FP16;

    // Reset the parameters
    memset(&params, 0, sizeof(params));

    // Set the pointers and strides.
    params.qkv_ptr = qkv_packed_d;
    params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
    params.o_ptr = o_packed_d;
    params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);

    params.cu_seqlens = static_cast<int *>(cu_seqlens_d);
    params.seqlens = static_cast<int *>(seqlens_d);

    // S = softmax(P)
    params.s_ptr = s_d;
    params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);

    // Set the dimensions.
    params.b = b;
    params.h = h;
    params.s = s;
    params.d = d;

    // Set the different scale values.
    const float scale_bmm1 = 1.f / sqrtf(d);
    constexpr float scale_softmax = 1.f;
    constexpr float scale_bmm2 = 1.f;

    set_alpha(params.scale_bmm1, scale_bmm1, acc_type);
    set_alpha(params.scale_softmax, scale_softmax, acc_type);
    set_alpha(params.scale_bmm2, scale_bmm2, data_type);

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
    params.rp_dropout = 1.f / params.p_dropout;
    TORCH_CHECK(p_dropout < 1.f);
    set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}

constexpr uint32_t NUM_HEADS_DIM = 2;
constexpr uint32_t THREE_DIM = 1;

std::vector<at::Tensor> mha_fwd(const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
                                const at::Tensor &cu_seqlens,  // b+1
                                const at::Tensor &seqlens,     // b
                                const float p_dropout,
                                const int max_seq_len,
                                const bool is_training,
                                c10::optional<at::Generator> gen_) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
    int seq_len = 512;
    auto launch = &run_fmha_fp16_512_64_sm80;
    if( max_seq_len <= 128 ) {
        seq_len = 128;
        launch = &run_fmha_fp16_128_64_sm80;
    } else if( max_seq_len <= 256 ) {
        seq_len = 256;
        launch = &run_fmha_fp16_256_64_sm80;
    } else if( max_seq_len <= 384 ) {
        seq_len = 384;
        launch = &run_fmha_fp16_384_64_sm80;
    } else if( max_seq_len <= 512 ) {
        seq_len = 512;
        launch = &run_fmha_fp16_512_64_sm80;
    } else {
        TORCH_CHECK(false);
    }

    constexpr int warps_m = 1;
    constexpr int warps_n = 4;
    // this leads to an upper bound
    const int mmas_m = seq_len / 16 / warps_m;
    const int mmas_n = seq_len / 16 / warps_n;
    const int elts_per_thread = 8 * mmas_m * mmas_n;

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
    TORCH_CHECK(seqlens.dtype() == torch::kInt32);

    TORCH_CHECK(qkv.is_cuda())
    TORCH_CHECK(cu_seqlens.is_cuda())

    TORCH_CHECK(qkv.is_contiguous())
    TORCH_CHECK(cu_seqlens.is_contiguous())
    TORCH_CHECK(seqlens.is_contiguous())

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);

    const auto sizes = qkv.sizes();

    TORCH_CHECK(sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    TORCH_CHECK(seqlens.numel() == batch_size);
    const int total = sizes[0];
    const int num_heads = sizes[NUM_HEADS_DIM];
    const int head_size = sizes[3];
    TORCH_CHECK(batch_size > 0);
    TORCH_CHECK(head_size == 64);

    auto opts = qkv.options();

    auto ctx = torch::empty({total, num_heads, head_size}, opts);

    auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts);

    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());

    Fused_multihead_attention_fprop_params params;

    set_params(params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               seqlens.data_ptr(),
               ctx.data_ptr(),
               s.data_ptr(),
               p_dropout);

    // number of times random will be generated per thread, to offset philox counter in thc random
    // state
    int64_t counter_offset = elts_per_thread;
    at::PhiloxCudaState rng_engine_inputs;

    if( is_training ) {
        // See Note [Acquire lock when using random generators]
        std::lock_guard<std::mutex> lock(gen->mutex_);
        params.philox_args = gen->philox_cuda_state(counter_offset);
    }

    launch(params, is_training, stream);

    return { ctx, s };
}

std::vector<at::Tensor> mha_bwd(const at::Tensor &dout,        // total x num_heads x head_size
                                const at::Tensor &qkv,         // total x num_heads x 3 x head_size, total := \sum_{i=0}^{b} s_i
                                at::Tensor &softmax,           // b x h x s x s softmax and dmask - will be overwritten with dP
                                const at::Tensor &cu_seqlens,  // b+1
                                const at::Tensor &seqlens,     // b
                                const float p_dropout,         // probability to drop
                                const int max_seq_len          // max sequence length to choose the kernel
) {
    auto dprops = at::cuda::getCurrentDeviceProperties();
    TORCH_CHECK(dprops->major == 8 && dprops->minor == 0);
    int seq_len = 512;
    auto launch = &run_fmha_dgrad_fp16_512_64_sm80;
    if( max_seq_len <= 128 ) {
        seq_len = 128;
        launch = &run_fmha_dgrad_fp16_128_64_sm80;
    } else if( max_seq_len <= 256 ) {
        seq_len = 256;
        launch = &run_fmha_dgrad_fp16_256_64_sm80;
    } else if( max_seq_len <= 384 ) {
        seq_len = 384;
        launch = &run_fmha_dgrad_fp16_384_64_sm80;
    } else if( max_seq_len <= 512 ) {
        seq_len = 512;
        launch = &run_fmha_dgrad_fp16_512_64_sm80;
    } else {
        TORCH_CHECK(false);
    }

    auto stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK(qkv.dtype() == torch::kFloat16);
    TORCH_CHECK(dout.dtype() == torch::kFloat16);
    TORCH_CHECK(softmax.dtype() == torch::kFloat16);
    TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
    TORCH_CHECK(seqlens.dtype() == torch::kInt32);

    TORCH_CHECK(qkv.is_cuda());
    TORCH_CHECK(cu_seqlens.is_cuda());

    TORCH_CHECK(qkv.is_contiguous());
    TORCH_CHECK(cu_seqlens.is_contiguous());
    TORCH_CHECK(seqlens.is_contiguous());

    TORCH_CHECK(cu_seqlens.dim() == 1);
    TORCH_CHECK(seqlens.dim() == 1);
    TORCH_CHECK(qkv.dim() == 4);

    const auto sizes = qkv.sizes();

    TORCH_CHECK(sizes[THREE_DIM] == 3);

    const int batch_size = cu_seqlens.numel() - 1;
    TORCH_CHECK(seqlens.numel() == batch_size);
    const int num_heads = sizes[NUM_HEADS_DIM];
    const int head_size = sizes[3];
    TORCH_CHECK(batch_size > 0);
    TORCH_CHECK(head_size == 64);

    auto dqkv = torch::empty_like(qkv);

    Fused_multihead_attention_fprop_params params;

    set_params(params,
               batch_size,
               seq_len,
               num_heads,
               head_size,
               qkv.data_ptr(),
               cu_seqlens.data_ptr(),
               seqlens.data_ptr(),
               dout.data_ptr(),     // we set o_ptr to dout
               softmax.data_ptr(),  // softmax gets overwritten by dP!
               p_dropout);

    // we're re-using these scales
    Data_type acc_type = DATA_TYPE_FP32;
    set_alpha(params.scale_bmm1, 1.f, acc_type);
    set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
    set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
    params.dqkv_ptr = dqkv.data_ptr();

    launch(params, stream);

    return { dqkv, softmax };
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.doc() = "Fused Multi-head Self-attention for BERT";
    m.def("fwd", &mha_fwd, "Forward pass");
    m.def("bwd", &mha_bwd, "Backward pass");
}
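
A quick host-side check of the stride arithmetic in set_params above. This is a minimal sketch: the batch/head/sequence sizes are illustrative (not taken from the patch), and the 2-byte element size is the one implied by DATA_TYPE_FP16.

#include <cassert>
#include <cstddef>

int main() {
    // Illustrative shapes: b = batch, s = max sequence length, h = heads, d = head size.
    const size_t b = 32, s = 512, h = 16, d = 64;
    const size_t bytes_per_elt = 2;  // FP16, as in DATA_TYPE_FP16

    // params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type)
    assert(h * 3 * d * bytes_per_elt == 6144);
    // params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type)
    assert(h * d * bytes_per_elt == 2048);
    // params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type)
    assert(b * h * s * bytes_per_elt == 524288);
    return 0;
}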
apex/contrib/csrc/fmha/src/fmha.h  (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <fmha_utils.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
    // The QKV matrices.
    void *qkv_ptr;
    // The stride between rows of the Q, K and V matrices.
    size_t qkv_stride_in_bytes;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Fused_multihead_attention_fprop_params : public Qkv_params {

    // The dQKV matrices.
    void *dqkv_ptr;

    // The O matrix (output).
    void *o_ptr;

    // The stride between rows of O.
    int64_t o_stride_in_bytes;

    // The pointer to the S matrix, overwritten by the dP matrix (bwd).
    void *s_ptr;
    // The stride between rows of the S matrix.
    int64_t s_stride_in_bytes;

    // The dimensions.
    int b, h, s, d;

    // The scaling factors for the kernel.
    uint32_t scale_bmm1, scale_softmax, scale_bmm2;

    // array of length b+1 holding starting offset of each sequence.
    int *cu_seqlens;
    // array of length b holding the actual sequence lengths.
    int *seqlens;

    // The dropout probability (probability of keeping an activation).
    float p_dropout;

    // Scale factor of 1 / (1 - p_dropout).
    float rp_dropout;

    // Scale factor of 1 / (1 - p_dropout), in half2.
    uint32_t scale_dropout;

    // Random state.
    at::PhiloxCudaState philox_args;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
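
The cu_seqlens/seqlens convention documented in the struct above can be illustrated with a small host-side helper. This is a sketch only; build_cu_seqlens is a hypothetical name and is not part of the patch.

#include <numeric>
#include <vector>

// Build the b+1 prefix-sum offsets that cu_seqlens describes from the
// per-sequence lengths (the contents of seqlens).
std::vector<int> build_cu_seqlens(const std::vector<int> &seqlens) {
    std::vector<int> cu(seqlens.size() + 1, 0);
    std::partial_sum(seqlens.begin(), seqlens.end(), cu.begin() + 1);
    return cu;  // cu[i] = starting token offset of sequence i; cu[b] == total
}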
apex/contrib/csrc/fmha/src/fmha/gemm.h  (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#define FMHA_DIV_UP(m, n) (((m) + (n)-1) / (n))
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
struct Fragment_base_ {

    // The data type.
    using Data_type = Data_type_;
    // default input type
    using Input_type_ = Data_type_;
    // Does it store the array of elements.
    enum { HAS_ELTS = BITS_PER_ELT_ >= 8 };
    // The number of elements.
    enum { NUM_ELTS = NUM_ELTS_ };
    // The size of element in bits.
    enum { BITS_PER_ELT = BITS_PER_ELT_ };
    // The size of byte of a single register.
    enum { BYTES_PER_REG = 4 };
    // The size in bits.
    enum { BITS_PER_REG = BYTES_PER_REG * 8 };
    // The number of registers needed to store the fragment.
    enum { NUM_REGS = Div_up<NUM_ELTS * BITS_PER_ELT, BITS_PER_REG>::VALUE };
    // The size in bytes (as returned by sizeof(Fragment_base<>).
    enum { SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG };
    // The alignment.
    enum { ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : Min<NUM_REGS * BYTES_PER_REG, 16>::VALUE };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The type of the elements.
    typename Data_type_,
    // The number of elements.
    int NUM_ELTS_,
    // The alignment if you want to force a value -- use 0 otherwise.
    int ALIGNMENT_ = 0,
    // The base class.
    typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
>
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {

    // The size of a load/store.
    enum { BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t) };

    // Clear the fragment. Using PTX in that code seems to produce better SASS...
    inline __device__ void clear() {
        #pragma unroll
        for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
            asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) :);
        }
    }

    // Immutable access to a register.
    inline __device__ const uint32_t& reg(int ii) const {
        return this->regs_[ii];
    }

    // Mutable access to a register.
    inline __device__ uint32_t& reg(int ii) {
        return this->regs_[ii];
    }

    uint32_t regs_[Base_::NUM_REGS];

    // Immutable access to the elements.
    inline __device__ const Data_type_& elt(int ii) const {
        return reinterpret_cast<const Data_type_ *>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    inline __device__ Data_type_& elt(int ii) {
        return reinterpret_cast<Data_type_ *>(&this->regs_[0])[ii];
    }

    // Immutable access to the elements with a cast.
    template< typename Cast_type >
    inline __device__ const Cast_type& elt_as(int ii) const {
        return reinterpret_cast<const Cast_type *>(&this->regs_[0])[ii];
    }

    // Mutable access to the elements.
    template< typename Cast_type >
    inline __device__ Cast_type& elt_as(int ii) {
        return reinterpret_cast<Cast_type *>(&this->regs_[0])[ii];
    }

    // Add another fragment.
    inline __device__ void add(const Fragment &other) {
        #pragma unroll
        for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
            this->elt(ii) += other.elt(ii);
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Layout >
struct Fragment_a : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Layout >
struct Fragment_b : public Fragment<uint16_t, 8> {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Fragment_accumulator : public Fragment<float, 8> {

    // The base class.
    using Base = Fragment<float, 8>;

    // Add two fragments.
    template< typename Other_fragment_ >
    inline __device__ void add(const Other_fragment_ &other) {
        for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
            this->elt(ii) = this->elt(ii) + other.elt(ii);
        }
    }

    // Do the HMMA.
    template< typename Layout_a, typename Layout_b >
    inline __device__ void mma(const Fragment_a<Layout_a> &a, const Fragment_b<Layout_b> &b) {
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
            : "+f"(elt(0)), "+f"(elt(1)), "+f"(elt(2)), "+f"(elt(3))
            : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
            , "r"(b.reg(0)), "r"(b.reg(1)));
        asm volatile( \
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
            "    {%0, %1, %2, %3}, \n" \
            "    {%4, %5, %6, %7}, \n" \
            "    {%8, %9}, \n" \
            "    {%0, %1, %2, %3}; \n" \
            : "+f"(elt(4)), "+f"(elt(5)), "+f"(elt(6)), "+f"(elt(7))
            : "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
            , "r"(b.reg(2)), "r"(b.reg(3)));
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Fragment, int M, int N >
inline __device__ void clear(Fragment (&frag)[M][N]) {
    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            frag[mi][ni].clear();
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Accumulator_type, int WARPS_K >
struct Clear_accumulator {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int WARPS_K >
struct Clear_accumulator<float, WARPS_K> {
    template< typename Acc, int M, int N >
    static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
        fmha::clear(acc);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Acc, typename A, typename B, int M, int N >
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {

    #pragma unroll
    for( int mi = 0; mi < M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < N; ++ni ) {
            acc[mi][ni].mma(a[mi], b[ni]);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The number of rows in the CTA tile.
    int M_,
    // The number of cols in the CTA tile.
    int N_,
    // The number of elements in the K dimension of the GEMM loop.
    int K_,
    // The number of rows of warps.
    int WARPS_M_,
    // The number of cols of warps.
    int WARPS_N_,
    // The number of warps in the K dimension of the GEMM loop.
    int WARPS_K_
>
struct Cta_tile_ {

    enum { M = M_, N = N_, K = K_ };
    // The number of warps.
    enum { WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_ };
    // The number of warps per CTA.
    enum { WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K };
    // The number of threads per warp.
    enum { THREADS_PER_WARP = 32 };
    // The number of threads per CTA.
    enum { THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Cta_tile >
struct Hmma_tile {
    // The number of elements computed with a single warp-MMA.
    enum { M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16 };

    // The number of elements computed with a single CTA-MMA.
    enum {
        M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
        N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
        K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K
    };

    // The number of MMAs needed to compute the GEMM.
    enum {
        MMAS_M = Div_up<Cta_tile::M, M_PER_MMA_PER_CTA>::VALUE,
        MMAS_N = Div_up<Cta_tile::N, N_PER_MMA_PER_CTA>::VALUE,
        MMAS_K = Div_up<Cta_tile::K, K_PER_MMA_PER_CTA>::VALUE,
    };

    // The number of elements computed per warp.
    enum {
        M_PER_WARP = MMAS_M * M_PER_MMA,
        N_PER_WARP = MMAS_N * N_PER_MMA,
        K_PER_WARP = MMAS_K * K_PER_MMA,
    };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

using A_type = uint16_t;
using B_type = uint16_t;
using C_type = uint16_t;
using Accumulator_type = float;
using Epilogue_type = float;

constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K >
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Cta_tile_ >
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
                                                   Cta_tile_::N,
                                                   Next_power_of_two<Cta_tile_::K>::VALUE,
                                                   Cta_tile_::WARPS_M,
                                                   Cta_tile_::WARPS_N,
                                                   Cta_tile_::WARPS_K>;

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
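
To make the MMA bookkeeping concrete, the sketch below instantiates Cta_tile_extd and Hmma_tile for a 16-row step of Q against 512 keys at head size 64 with a 1x4 warp layout (the same warps_m/warps_n used for the upper-bound estimate in mha_fwd). The tile parameters and the include path are illustrative assumptions, not taken from the kernel .cu files.

#include <fmha/gemm.h>

using Cta_tile_p_example = fmha::Cta_tile_extd<16, 512, 64, 1, 4, 1>;
using Mma_tile_p_example = fmha::Hmma_tile<Cta_tile_p_example>;

static_assert(Cta_tile_p_example::THREADS_PER_CTA == 128, "1 x 4 x 1 warps -> 128 threads");
static_assert(Mma_tile_p_example::MMAS_M == 1, "16 rows / (16 * WARPS_M)");
static_assert(Mma_tile_p_example::MMAS_N == 8, "512 cols / (16 * WARPS_N)");
static_assert(Mma_tile_p_example::MMAS_K == 4, "64-deep K / (16 * WARPS_K)");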
apex/contrib/csrc/fmha/src/fmha/gmem_tile.h  (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The number of bits per element.
    int BITS_PER_ELEMENT,
    // The number of rows of Q, K or V loaded by this tile.
    int ROWS,
    // The number of columns.
    int COLS
>
struct Gmem_tile_qkv {

    // The size of each LDG.
    enum { BYTES_PER_LDG = 16 };
    // The size of a row in bytes.
    enum { BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8 };

    // The number of threads to load a "row" of the matrix.
    enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG };

    // The number of "rows" loaded per LDG.
    enum { ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    // The number of LDGs needed to load a chunk of the Q matrix.
    enum { LDGS = fmha::Div_up<ROWS, ROWS_PER_LDG>::VALUE };

    // Ctor.
    template< typename Params, typename BInfo >
    inline __device__ Gmem_tile_qkv(const Params &params, int qkv_offset, const BInfo &binfo, int tidx)
        : params_qkv_stride_in_bytes_(params.qkv_stride_in_bytes)
        , actual_seqlen(binfo.actual_seqlen)
        , qkv_ptr_(reinterpret_cast<char *>(params.qkv_ptr)) {

        // Compute the position in the sequence (within the CTA for the moment).
        int row = tidx / THREADS_PER_ROW;
        // Compute the position of the thread in the row.
        int col = tidx % THREADS_PER_ROW;

        // Store the row as we need it to disable the loads.
        row_ = row;

        // The row offset in the batched GEMM. For each seq element, we store QKV in that order.
        int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
        // Add the block index.
        row_offset += (int64_t)((binfo.sum_s * 3 + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;

        // Assemble the final pointer.
        qkv_ptr_ += row_offset + col * BYTES_PER_LDG;
    }

    // Store data to shared memory.
    template< typename Smem_tile >
    inline __device__ void commit(Smem_tile &smem_tile) {
        smem_tile.store(fetch_);
    }

    // Load data from memory.
    template< typename Smem_tile >
    inline __device__ void load(Smem_tile &smem_tile) {
        const void *ptrs[LDGS];
        uint32_t preds[LDGS];
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
            preds[ii] = ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
            fetch_[ii] = make_uint4(0, 0, 0, 0);
        }

        // not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
        Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            fct.load(ii, preds[ii]);
        }
    }

    // Store data to memory.
    inline __device__ void store(const uint4 (&data)[LDGS]) {
        #pragma unroll
        for( int ii = 0; ii < LDGS; ++ii ) {
            char *ptr = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
            if( (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen) ) {
                fmha::stg(ptr, data[ii]);
            }
        }
    }

    // Move the pointer to the next location.
    inline __device__ void move() {
        qkv_ptr_ += (int64_t)ROWS * params_qkv_stride_in_bytes_;
        actual_seqlen -= ROWS;
    }

    // The stride between rows for the QKV matrix.
    int64_t params_qkv_stride_in_bytes_;
    // The pointer.
    char *qkv_ptr_;
    // The fetch registers.
    uint4 fetch_[LDGS];
    // Keep track of the row the thread is processing as we move the tile.
    int row_;
    // The length of the sequence loaded by that memory tile.
    int actual_seqlen;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Cta_tile >
struct Gmem_tile_o {

    // The mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The size of each element.
    enum { BYTES_PER_ELEMENT = 2 };
    // The size of a row in bytes.
    enum { BYTES_PER_ROW = Cta_tile::N * BYTES_PER_ELEMENT };

    // The number of threads to store a "row" of the matrix.
    enum { THREADS_PER_ROW = 16 };
    // The size of each STG.
    enum { BYTES_PER_STG = BYTES_PER_ROW / THREADS_PER_ROW };

    // The number of "rows" stored per iteration of the loop. The output of 1 MMA.
    enum { ROWS = Cta_tile::M };
    // The number of "rows" stored per iteration of the loop. The output of 1 MMA.
    enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };
    // The number of outer loops for the stores.
    enum { LOOPS = ROWS / ROWS_PER_LOOP };

    // The number of "rows" stored per STG.
    enum { ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    // Do we have to guard against partial writes/reads.
    enum { HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0 };
    // The number of STGs needed to store a chunk of the Q matrix.
    enum { STGS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_STG>::VALUE };
    // The number of STGs needed to store a chunk of the Q matrix in total.
    enum { STGS = STGS_PER_LOOP * LOOPS };

    // Ctor.
    template< typename Params, typename BInfo >
    inline __device__ Gmem_tile_o(const Params &params, const BInfo &binfo, int tidx)
        : params_o_stride_in_bytes_(params.o_stride_in_bytes)
        , actual_seqlen_(binfo.actual_seqlen)
        , o_ptr_(reinterpret_cast<char *>(params.o_ptr)) {

        // Compute the position in the sequence (within the CTA for the moment).
        int row = tidx / THREADS_PER_ROW;
        // Compute the position of the thread in the row.
        int col = tidx % THREADS_PER_ROW;

        // Store the row as we need it to disable loads.
        row_ = row;

        // The row offset in the batched GEMM.
        int64_t row_offset = (int64_t)row * params.o_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
        // Assemble the final pointer.
        o_ptr_ += row_offset + col * BYTES_PER_STG;

        // Is that thread active on the last STG?
        if( HAS_INCOMPLETE_STG ) {
            is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
        }
    }

    // Store data to global memory.
    inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
        #pragma unroll
        for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
            int jj = mi * STGS_PER_LOOP + ii;
            if( this->row_ + jj * ROWS_PER_STG >= this->actual_seqlen_ ) {
                break;
            }

            float x = reinterpret_cast<const float &>(src[ii].x);
            float y = reinterpret_cast<const float &>(src[ii].y);
            float z = reinterpret_cast<const float &>(src[ii].z);
            float w = reinterpret_cast<const float &>(src[ii].w);
            uint2 out = float4_to_half4(x, y, z, w);
            if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
                fmha::stg(this->o_ptr_ + jj * ROWS_PER_STG * this->params_o_stride_in_bytes_, out);
            }
        }
    }

    // Move the pointer to the next location.
    inline __device__ void move() {
        row_ += ROWS;
        o_ptr_ += (int64_t)ROWS * params_o_stride_in_bytes_;
    }

    // The stride between rows for the O matrix.
    int64_t params_o_stride_in_bytes_;
    // The pointer.
    char *o_ptr_;
    // Is the thread active for the last STG?
    int is_active_for_last_stg_;
    // Keep track of the row to disable loads.
    int row_;
    // The length of the sequence loaded by that memory tile.
    int actual_seqlen_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Cta_tile, int BYTES_PER_ELEMENT >
struct Gmem_tile_mma_sd {

    // The mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // Each STG stores 8 elements.
    enum { BYTES_PER_STG = BYTES_PER_ELEMENT * 8 };
    // The number of MMAs in the M dimension.
    enum { MMAS_M = Mma_tile::MMAS_M };
    // The number of MMAs in the N dimension.
    enum { MMAS_N = Mma_tile::MMAS_N };
    // The number of rows computed per MMA per thread block.
    enum { M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA };
    // The number of cols computed per MMA per thread block.
    enum { N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA };
    // The number of threads per block.
    enum { THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA };
    // The size of each row in bytes. I.e. how many bytes are stored per STG.
    enum { BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG };
    // The fixed sequence length.
    enum { SEQLEN = Cta_tile::N };
    // The distance between two blocks (in bytes).
    enum { BLOCK_STRIDE_BYTES = SEQLEN * SEQLEN * BYTES_PER_ELEMENT };
    // The distance between elements stored per loop (in bytes).
    enum { LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW };

    // The type of elements stored per STG.
    using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;

    // Ctor.
    template< typename Params >
    inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int tidx)
        : ptr_(static_cast<char *>(ptr)) {

        // The block index for the batch.
        const int bidb = blockIdx.y;
        // The block index for the head.
        const int bidh = blockIdx.x;
        // The block index.
        size_t bidx = bidb * params.h + bidh;

        // Set store location for each thread at the beginning of the loop
        ptr_ += bidx * BLOCK_STRIDE_BYTES + tidx * BYTES_PER_STG;
    }

    // Store to global memory.
    inline __device__ void store(const Type &data, const int mi, const int ni) {
        size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        fmha::stg(ptr_ + offset, data);
    }

    // Load from global memory.
    inline __device__ void load(Type &data, const int mi, const int ni) {
        size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
        fmha::ldg(data, ptr_ + offset);
    }

    // Move to the next tile.
    inline __device__ void move() {
        ptr_ += LOOP_STRIDE_BYTES;
    }

    // The pointer in global memory.
    char *ptr_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
struct Gmem_tile_mma_s : public Base {

    // The number of mmas in the vertical dimension.
    enum { M = Base::MMAS_M };
    // The number of mmas in the horizontal dimension.
    enum { N = Base::MMAS_N };
    // The type of the vectors stored by each STG.
    using Type = typename Base::Type;

    // Ctor.
    template< typename Params >
    inline __device__ Gmem_tile_mma_s(void *ptr, const Params &params, const int tidx)
        : Base(ptr, params, tidx) {
    }

    // Store to global memory.
    template< typename Mask >
    inline __device__ void store(const float (&softmax)[2 * M][4 * N], const Mask &mask) {
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {

                float tmp00 = softmax[2 * mi + 0][4 * ni + 0];
                float tmp01 = softmax[2 * mi + 0][4 * ni + 1];
                float tmp02 = softmax[2 * mi + 0][4 * ni + 2];
                float tmp03 = softmax[2 * mi + 0][4 * ni + 3];

                float tmp10 = softmax[2 * mi + 1][4 * ni + 0];
                float tmp11 = softmax[2 * mi + 1][4 * ni + 1];
                float tmp12 = softmax[2 * mi + 1][4 * ni + 2];
                float tmp13 = softmax[2 * mi + 1][4 * ni + 3];

                uint4 dst;
                dst.x = fmha::float2_to_half2(tmp00, tmp01);
                dst.y = fmha::float2_to_half2(tmp02, tmp03);
                dst.z = fmha::float2_to_half2(tmp10, tmp11);
                dst.w = fmha::float2_to_half2(tmp12, tmp13);
                if( mask.is_valid(mi, ni, 0, 0) ) {
                    Base::store(dst, mi, ni);
                }
            }
        }
    }

    // Load from global memory.
    template< typename Mask >
    inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                regs[mi][ni] = make_uint4(0, 0, 0, 0);
                if( mask.is_valid(mi, ni, 0, 0) ) {
                    Base::load(regs[mi][ni], mi, ni);
                }
            }
        }
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The base class.
    typename Base = fmha::Gmem_tile_qkv<Cta_tile, fmha::BITS_PER_ELEMENT_A, Cta_tile::M, Cta_tile::K>
>
struct Gmem_tile_dout : public Base {

    // Ctor.
    template< typename Params, typename BInfo >
    inline __device__ Gmem_tile_dout(const Params &params, const BInfo &binfo, int tidx)
        : Base(params, 0, binfo, tidx) {

        this->qkv_ptr_ = reinterpret_cast<char *>(params.o_ptr);
        this->params_qkv_stride_in_bytes_ = params.o_stride_in_bytes;  // needed for move

        // Compute the position of the thread in the row.
        int col = tidx % Base::THREADS_PER_ROW;

        // The row offset in the batched GEMM. For each seq element, we store O in that order.
        int64_t row_offset = (int64_t)this->row_ * params.o_stride_in_bytes + binfo.bidx * Base::BYTES_PER_ROW;

        // Assemble the final pointer.
        this->qkv_ptr_ += row_offset + col * Base::BYTES_PER_LDG;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Cta_tile, typename Base = fmha::Gmem_tile_o<Cta_tile> >
struct Gmem_tile_dq : public Base {

    // Ctor.
    template< typename Params, typename BInfo >
    inline __device__ Gmem_tile_dq(const Params &params, const BInfo &binfo, int tidx)
        : Base(params, binfo, tidx) {

        this->o_ptr_ = reinterpret_cast<char *>(params.dqkv_ptr);
        this->params_o_stride_in_bytes_ = params.qkv_stride_in_bytes;  // needed for move

        // Compute the position of the thread in the row.
        int col = tidx % Base::THREADS_PER_ROW;

        // The row offset in the batched GEMM. For each seq element, we store O in that order.
        int64_t row_offset = (int64_t)this->row_ * params.qkv_stride_in_bytes +
                             (binfo.sum_s * 3 * binfo.h + binfo.bidh) * Base::BYTES_PER_ROW;

        // Assemble the final pointer.
        this->o_ptr_ += row_offset + col * Base::BYTES_PER_STG;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
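
The per-thread load geometry of Gmem_tile_qkv can be checked at compile time. The sketch below assumes a 128-thread CTA and a 512 x 64 FP16 operand (illustrative parameters) and assumes the fmha headers are included in the same order the kernel translation units use, compiled with nvcc.

#include <fmha/gemm.h>
#include <fmha/gmem_tile.h>

using Cta_tile_ldg_example = fmha::Cta_tile_extd<16, 512, 64, 1, 4, 1>;  // 128 threads
using Gmem_tile_k_example = fmha::Gmem_tile_qkv<Cta_tile_ldg_example, 16 /*bits*/, 512 /*rows*/, 64 /*cols*/>;

static_assert(Gmem_tile_k_example::BYTES_PER_ROW == 128, "64 fp16 elements per row");
static_assert(Gmem_tile_k_example::THREADS_PER_ROW == 8, "one 16B LDG covers 8 elements");
static_assert(Gmem_tile_k_example::ROWS_PER_LDG == 16, "128 threads / 8 threads per row");
static_assert(Gmem_tile_k_example::LDGS == 32, "512 rows / 16 rows per LDG");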
apex/contrib/csrc/fmha/src/fmha/kernel_traits.h  (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x8u >
struct FMHA_kernel_traits {

    // The CTA description for the 1st GEMM.
    using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
    // The CTA description for the 2nd GEMM.
    using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;

    // Do we use one buffer for K and V.
    enum { SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x8u) != 0u };

    // The global memory tile to load Q.
    using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;

    // The shared memory tile to swizzle Q.
    using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;

    // The global memory tile to load K.
    using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;

    // The global memory tile to load V.
    using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;

    // The global memory tile to store O.
    using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;
    // The shared memory tile for O.
    using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;

    // The global memory tile to load/store S.
    using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;

    // The shared memory tile to transpose S.
    using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;

    using Gmem_tile_do = fmha::Gmem_tile_dout<Cta_tile_p>;

    // Make sure the number of threads match.
    static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");

    // The number of threads.
    enum { THREADS = Cta_tile_p::THREADS_PER_CTA };
    // Make sure the number of threads matches both CTAs.
    static_assert((int)THREADS == (int)Cta_tile_o::THREADS_PER_CTA, "");

    // The amount of shared memory needed to load Q and K.
    enum { BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE };
    // The extra amount of shared memory needed to load V.
    enum { BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE };
    // The amount of shared memory needed for Q, K and V.
    enum { BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V };
    // The amount of shared memory needed to load Q and store O.
    enum { BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE };

    // The amount of shared memory needed for Q, K, V and O.
    enum { BYTES_PER_SMEM = fmha::Max<BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO>::VALUE };
    // Make sure we have enough shared memory.
    static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
};

////////////////////////////////////////////////////////////////////////////////////////////////////
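
A hypothetical instantiation, shown only for orientation; the parameters actually compiled per sequence length are chosen in the fmha_fprop_fp16_*_kernel.sm80.cu and fmha_dgrad_fp16_*_kernel.sm80.cu files, which are not reproduced above, and this alias assumes the full set of fmha headers has already been included.

// With S=128, D=64, STEP=16, WARPS_M=1, WARPS_N=4 and the default FLAGS=0x8u,
// the template above resolves to:
//   Cta_tile_p = fmha::Cta_tile_extd<16, 128, 64, 1, 4, 1>   (1st GEMM)
//   Cta_tile_o = fmha::Cta_tile_extd<16, 64, 128, 1, 1, 4>   (2nd GEMM)
//   SHARE_SMEM_FOR_K_AND_V = 1, i.e. K and V reuse one shared-memory buffer.
using Kernel_traits_128_example = FMHA_kernel_traits<128, 64, 16, 1, 4>;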
apex/contrib/csrc/fmha/src/fmha/mask.h  (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
namespace fmha {

template< typename Cta_tile >
struct Mask {
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    template< typename Params, typename BInfo >
    __device__ Mask(const Params &params, const BInfo &blockInfo, int tidx) {

        actual_seqlen = blockInfo.actual_seqlen;

        const int warp = tidx / Cta_tile::THREADS_PER_WARP;
        const int lane = tidx % Cta_tile::THREADS_PER_WARP;

        static_assert(Cta_tile::WARPS_K == 1, "");

        // find the warp in the Cta tile
        const int warp_n = (warp / Cta_tile::WARPS_M);
        const int warp_m = (warp % Cta_tile::WARPS_M);
        // decompose warp into 8x4 tile
        const int quad = lane / 4;
        const int tid = (lane % 4) * 2;
        row = warp_m * 16 + quad;
        col = warp_n * 16 + tid;
    }

    inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {

        // ii and jj iterate over the 2x4 fragment
        const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen;
        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen;
        return col_valid;
        // return row_valid && col_valid;
    }

    inline __device__ void load(int it) {
        row_offset = it * Cta_tile::M + row;
    }
    int row_offset;

    int row;
    int col;
    int actual_seqlen;
};

}  // namespace fmha
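
The lane-to-(row, col) decomposition in the Mask constructor follows the 8x4 thread layout of the m16n8k16 MMA fragments. A small host-side sketch of the mapping for a single warp (warp_m = warp_n = 0), assuming THREADS_PER_WARP == 32; the loop and printout are illustrative only.

#include <cstdio>

int main() {
    for (int lane = 0; lane < 32; ++lane) {
        const int quad = lane / 4;       // row within the 16x16 MMA tile
        const int tid = (lane % 4) * 2;  // first of the column pair owned by the thread
        const int row = 0 * 16 + quad;   // warp_m = 0
        const int col = 0 * 16 + tid;    // warp_n = 0
        // Per is_valid, the jj index adds offsets {0, 1, 8, 9} to col.
        std::printf("lane %2d -> row %2d, cols %2d,%2d,%2d,%2d\n",
                    lane, row, col, col + 1, col + 8, col + 9);
    }
    return 0;
}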
apex/contrib/csrc/fmha/src/fmha/smem_tile.h  (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <fmha/utils.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The description of the tile computed by this CTA.
    typename Cta_tile,
    // The number of rows in the 2D shared memory buffer.
    int M_,
    // The number of cols.
    int N_,
    // The size in bits of each element.
    int BITS_PER_ELEMENT_,
    // The number of bytes per STS.
    int BYTES_PER_STS_ = 16,
    // The number of buffers. (Used in multistage and double buffer cases.)
    int BUFFERS_PER_TILE_ = 1,
    // Do we enable the fast path for LDS.128 and friends.
    int ENABLE_LDS_FAST_PATH_ = 0,
    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
    int ROWS_PER_XOR_PATTERN_ = 8,
    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
    int COLS_PER_XOR_PATTERN_ = 1,
    // Use or not predicates
    bool USE_PREDICATES_ = true
>
struct Smem_tile_without_skews {

    // The size in bits of each element.
    enum { BITS_PER_ELEMENT = BITS_PER_ELEMENT_ };
    // The size in bytes of a single STS.
    enum { BYTES_PER_STS = BYTES_PER_STS_ };
    // The number of elements per STS.
    enum { ELEMENTS_PER_STS = BYTES_PER_STS * 8 / BITS_PER_ELEMENT };
    // To support arbitrary N, we pad some values to a power-of-2.
    enum { N_WITH_PADDING = Next_power_of_two<N_>::VALUE };
    // The number of bytes per row without packing of rows.
    enum { BYTES_PER_ROW_BEFORE_PACKING = N_WITH_PADDING * BITS_PER_ELEMENT / 8 };
    // The number of bytes per row -- we want at least 128B per row.
    enum { BYTES_PER_ROW = Max<BYTES_PER_ROW_BEFORE_PACKING, 128>::VALUE };
    // The number of rows in shared memory (two rows may be packed into a single one).
    enum { ROWS = M_ * BYTES_PER_ROW_BEFORE_PACKING / BYTES_PER_ROW };

    // The number of threads per row.
    enum { THREADS_PER_ROW_UNBOUNDED = BYTES_PER_ROW / BYTES_PER_STS };
    // The number of threads per row.
    enum { THREADS_PER_ROW = Min<Cta_tile::THREADS_PER_CTA, THREADS_PER_ROW_UNBOUNDED>::VALUE };

    // The number of STS per row.
    enum { STS_PER_ROW = BYTES_PER_ROW / THREADS_PER_ROW / BYTES_PER_STS };
    // It must be at least one.
    static_assert(STS_PER_ROW >= 1, "");
    // The number of rows written with a single STS.
    enum { ROWS_PER_STS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    // Make sure we write to at least one row per STS. Thanks Dr. Obvious ;)
    static_assert(ROWS_PER_STS >= 1, "");
    // The number of STS needed to store all rows.
    enum { STS_PER_COL = Div_up<ROWS, ROWS_PER_STS>::VALUE };
    // The number of STS in total.
    enum { STS = STS_PER_COL * STS_PER_ROW };

    // The size of one buffer in bytes in shared memory.
    enum { BYTES_PER_BUFFER = STS * BYTES_PER_STS * Cta_tile::THREADS_PER_CTA };
    // The number of buffers.
    enum { BUFFERS_PER_TILE = BUFFERS_PER_TILE_ };
    // The size in bytes of total buffers.
    enum { BYTES_PER_TILE = BYTES_PER_BUFFER * BUFFERS_PER_TILE };
    // The boundary for smem_read_offset and smem_write_offset increment.
    enum { BYTES_PER_TILE_INC_BOUNDARY = BYTES_PER_TILE - BYTES_PER_BUFFER };

    // Do we enable the LDS.128 fast path?
    enum { ENABLE_LDS_FAST_PATH = ENABLE_LDS_FAST_PATH_ };
    static_assert(ENABLE_LDS_FAST_PATH == 0);
    // The number of rows that are used for the XOR swizzling to allow fast STS/LDS.
    enum { ROWS_PER_XOR_PATTERN = ROWS_PER_XOR_PATTERN_ };
    // The number of cols that are used for the XOR swizzling to allow fast STS/LDS.
    enum { COLS_PER_XOR_PATTERN = COLS_PER_XOR_PATTERN_ * 16 / BYTES_PER_STS };
    // Use or not predicates
    enum { USE_PREDICATES = USE_PREDICATES_ };

    // The type of elements that are stored in shared memory by each thread.
    using Store_type = typename Uint_from_size_in_bytes<BYTES_PER_STS>::Type;

    // Ctor.
    inline __device__ Smem_tile_without_skews(void *smem, int tidx)
        : smem_(__nvvm_get_smem_pointer(smem)) {

        // The row written by a thread. See doc/mma_smem_layout.xlsx.
        int smem_write_row = tidx / THREADS_PER_ROW;

        // The XOR pattern.
        int smem_write_xor = smem_write_row % ROWS_PER_XOR_PATTERN * COLS_PER_XOR_PATTERN;
        // Compute the column and apply the XOR pattern.
        int smem_write_col = (tidx % THREADS_PER_ROW) ^ smem_write_xor;

        // The offset.
        this->smem_write_offset_ = smem_write_row * BYTES_PER_ROW + smem_write_col * BYTES_PER_STS;

        // TODO: Why not merge it with the read offset?
        this->smem_read_buffer_ = __shfl_sync(0xffffffff, 0, 0);
        this->smem_write_buffer_ = __shfl_sync(0xffffffff, 0, 0);
    }

    // Compute the store pointers.
    template< int N >
    inline __device__ void compute_store_pointers(uint32_t (&ptrs)[N]) {
        #pragma unroll
        for( int ii = 0; ii < N; ++ii ) {
            // Decompose the STS into row/col.
            int row = ii / STS_PER_ROW;
            int col = ii % STS_PER_ROW;

            // Assemble the offset.
            int offset = smem_write_offset_ + row * ROWS_PER_STS * BYTES_PER_ROW;

            // Take the column into account.
            if( STS_PER_ROW > 1 ) {
                offset += col * THREADS_PER_ROW * BYTES_PER_STS;
            }

            // Apply the XOR pattern if needed.
            if( ROWS_PER_STS < ROWS_PER_XOR_PATTERN ) {
                const int m = row * ROWS_PER_STS % ROWS_PER_XOR_PATTERN;
                offset ^= m * COLS_PER_XOR_PATTERN * BYTES_PER_STS;
            }

            // Assemble the final pointer :)
            ptrs[ii] = smem_ + offset + smem_write_buffer_;
        }
    }

    inline __device__ void debug_reset() {
        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER ) {
            for( int row = 0; row < ROWS; ++row ) {
                for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
                    if( threadIdx.x == 0 ) {
                        uint32_t val = 0x0;
                        sts(val, smem_ + row * BYTES_PER_ROW + col + buffer);
                    }
                }
            }
        }
    }

    // Print the content of the tile (only for debug ;)).
    inline __device__ void debug_print() const {
        for( int buffer = 0; buffer < BYTES_PER_TILE; buffer += BYTES_PER_BUFFER ) {
            for( int row = 0; row < ROWS; ++row ) {
                for( int col = 0; col < BYTES_PER_ROW; col += 4 ) {
                    if( threadIdx.x == 0 ) {
                        uint32_t val;
                        lds(val, smem_ + row * BYTES_PER_ROW + col + buffer);
                        printf("block=(x=%2d, y=%2d, z=%2d) (smem_=%2d, buffer=%2d, row=%2d, byte=%4d)=0x%08x\n",
                               blockIdx.x,
                               blockIdx.y,
                               blockIdx.z,
                               smem_,
                               buffer,
                               row,
                               col,
                               val);
                    }
                }
            }
        }
    }

    // Move the read offset to next buffer.
    inline __device__ void move_to_next_read_buffer() {
        if( BUFFERS_PER_TILE > 1 && smem_read_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
            this->smem_read_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_read_buffer_ += BYTES_PER_BUFFER;
        }
    }

    // Move the read offset to next buffer. TODO: Remove this member function!!!
    inline __device__ void move_next_read_buffer() {
        this->move_to_next_read_buffer();
    }

    // Move the read offset to next N buffer (circular-buffer).
    inline __device__ void move_to_next_read_buffer(int N) {
        if( BUFFERS_PER_TILE > 1 ) {
            this->smem_read_buffer_ += N * BYTES_PER_BUFFER;
            this->smem_read_buffer_ -= smem_read_buffer_ >= BYTES_PER_TILE ? BYTES_PER_TILE : 0;
        }
    }

    // Move the read offset to next N buffer (circular-buffer). TODO: Remove this member function!!!
    inline __device__ void move_next_read_buffer(int N) {
        this->move_to_next_read_buffer(N);
    }

    // Move the write offset to next buffer.
    inline __device__ void move_to_next_write_buffer() {
        if( BUFFERS_PER_TILE > 1 && smem_write_buffer_ >= BYTES_PER_TILE_INC_BOUNDARY ) {
            this->smem_write_buffer_ -= BYTES_PER_TILE_INC_BOUNDARY;
        } else if( BUFFERS_PER_TILE > 1 ) {
            this->smem_write_buffer_ += BYTES_PER_BUFFER;
        }
    }

    // Move the write offset to next buffer. TODO: Remove that member function!
    inline __device__ void move_next_write_buffer() {
        this->move_to_next_write_buffer();
    }

    // Move the read offset.
    inline __device__ void move_read_offset(int delta) {
        this->smem_read_offset_ += delta;
    }

    // Move the write offset.
    inline __device__ void move_write_offset(int delta) {
        this->smem_write_offset_
+=
delta
;
}
// Store to the tile in shared memory.
template
<
int
N
>
inline
__device__
void
store
(
const
Store_type
(
&
data
)[
N
],
uint64_t
=
0
)
{
uint32_t
smem_ptrs
[
N
];
this
->
compute_store_pointers
(
smem_ptrs
);
sts
(
smem_ptrs
,
data
);
}
// Store to the tile in shared memory.
template
<
int
N
,
int
M
>
inline
__device__
void
store
(
const
Store_type
(
&
data
)[
N
],
uint32_t
(
&
preds
)[
M
],
uint64_t
=
0
)
{
uint32_t
smem_ptrs
[
N
];
this
->
compute_store_pointers
(
smem_ptrs
);
sts
(
smem_ptrs
,
data
,
preds
);
}
// Store to the tile in shared memory.
template
<
int
N
>
inline
__device__
void
store
(
const
Store_type
(
&
data
)[
N
],
uint32_t
preds
,
uint64_t
=
0
)
{
this
->
store
(
data
,
preds
);
}
// Store to the tile in shared memory.
template
<
int
N
>
inline
__device__
void
store
(
const
void
*
(
&
gmem_ptrs
)[
N
],
uint32_t
preds
,
uint64_t
=
0
)
{
uint32_t
tmp
[
1
]
=
{
preds
};
this
->
store
(
gmem_ptrs
,
tmp
);
}
// The shared memory pointer.
uint32_t
smem_
;
// The read offset. Reserve 4 offsets if needed.
int
smem_read_offset_
;
// The write offset.
int
smem_write_offset_
;
// The buffer base offset for read.
int
smem_read_buffer_
;
// The buffer base offset for write.
int
smem_write_buffer_
;
};
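// Illustrative sketch (not part of the original header): the write-swizzle arithmetic from the
// ctor above, restated as a constexpr helper for one assumed configuration (BYTES_PER_ROW = 128,
// BYTES_PER_STS = 16, THREADS_PER_ROW = 8, ROWS_PER_XOR_PATTERN = 8, COLS_PER_XOR_PATTERN = 1).
// The checks only spell out what the constructor computes for a few thread indices.
constexpr int example_write_offset(int tidx) {
    return (tidx / 8) * 128                               // smem_write_row * BYTES_PER_ROW
         + (((tidx % 8) ^ ((tidx / 8) % 8 * 1)) * 16);    // smem_write_col * BYTES_PER_STS
}
static_assert(example_write_offset(0) ==   0,      "");
static_assert(example_write_offset(8) == 128 + 16, "");   // row 1 is XOR-ed by one column,
static_assert(example_write_offset(9) == 128 +  0, "");   // so its columns are permuted.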
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The layout of the tile.
    typename Layout,
    // The size of the STS.
    int BYTES_PER_STS = 16,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE = 1,
    // Use or not predicates
    bool USE_PREDICATES = true
>
struct Smem_tile_a {
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template< int MMAS_K, int MMAS_K_WITH_PADDING >
struct Compute_reset_mask {
    // The potential mask.
    enum { HALF = MMAS_K_WITH_PADDING / 2 };
    // The remainder.
    enum { MOD = MMAS_K % HALF };
    // The final value.
    enum { VALUE = (MMAS_K == MOD ? 0 : HALF) | Compute_reset_mask<MOD, HALF>::VALUE };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int MMAS_K_WITH_PADDING >
struct Compute_reset_mask<0, MMAS_K_WITH_PADDING> {
    enum { VALUE = 0 };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int MMAS_K >
struct Compute_reset_mask<MMAS_K, MMAS_K> {
    enum { VALUE = MMAS_K - 1 };
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
struct Rows_per_xor_pattern_a {
    // The size in bits.
    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_A };
    // The number of rows.
    enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
struct Rows_per_xor_pattern_row_a : public Rows_per_xor_pattern_a<N> {
};
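// Worked example (illustrative, assuming BITS_PER_ELEMENT_A == 16 for these fp16 kernels):
// for the K = 64 head dimension used by the fmha kernels, N_IN_BITS = 64 * 16 = 1024 > 512,
// so the trait below selects an 8-row XOR pattern. That is what lets Smem_tile_row_a
// static_assert ROWS_PER_XOR_PATTERN == 8 in its ctor.
static_assert(Rows_per_xor_pattern_row_a<64>::VALUE == 8, "");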
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_a<Cta_tile::K>::VALUE
>
struct Smem_tile_row_a : public Smem_tile_without_skews<Cta_tile,
                                                        Cta_tile::M,
                                                        Cta_tile::K,
                                                        fmha::BITS_PER_ELEMENT_A,
                                                        BYTES_PER_STS,
                                                        BUFFERS_PER_TILE,
                                                        0,
                                                        ROWS_PER_XOR_PATTERN_,
                                                        1> {
    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile,
                                         Cta_tile::M,
                                         Cta_tile::K,
                                         fmha::BITS_PER_ELEMENT_A,
                                         BYTES_PER_STS,
                                         BUFFERS_PER_TILE,
                                         0,
                                         ROWS_PER_XOR_PATTERN_,
                                         1>;
    // The fragment.
    using Fragment = Fragment_a<Row>;
    // When we use padding to reach a power of two, special care has to be taken.
    using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
    // The number of MMAs.
    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = 16 };

    // Ctor.
    inline __device__ Smem_tile_row_a(void *smem, int tidx) : Base(smem, tidx) {

        // For documentation on the layout, see doc/mma_smem_layout.xlsx.

        // The number of warps.
        const int WARPS_M = Cta_tile::WARPS_M;
        const int WARPS_N = Cta_tile::WARPS_N;
        const int WARPS_K = Cta_tile::WARPS_K;

        static_assert(WARPS_M == 1);
        static_assert(WARPS_N == 4 || WARPS_N == 8);
        static_assert(WARPS_K == 1);
        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);

        // The row and column read by the thread.
        int smem_read_row = (tidx & 0x0f);
        int smem_read_col = (tidx & 0x07);
        smem_read_col ^= (tidx & 0x10) / 16;

        // The shared memory offset.
        this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
    }

    // Rewind smem_read_offset for last LDS phase in main loop.
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        // Undo the pointer increment for the next ni.
        // Should match the load function below for ki = 0.
        if( Mma_tile_with_padding::MMAS_K >= 2 ) {
            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
        }
    }

    // Load from shared memory.
    inline __device__ void load(Fragment (&a)[Mma_tile::MMAS_M], int ki) {
        #pragma unroll
        for( int mi = 0; mi < Mma_tile::MMAS_M; ++mi ) {
            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
            int offset = mi * Mma_tile::M_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;

            // Load using LDSM.M88.4.
            uint4 tmp;
            ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);

            // Store the value into the fragment.
            a[mi].reg(0) = tmp.x;
            a[mi].reg(1) = tmp.y;
            a[mi].reg(2) = tmp.z;
            a[mi].reg(3) = tmp.w;
        }

        // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
        static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
        if(        Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki %  8 ==  7 ) {
            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  8 && ki %  4 ==  3 ) {
            this->smem_read_offset_ ^=  7 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  4 && ki %  2 ==  1 ) {
            this->smem_read_offset_ ^=  3 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  2 ) {
            this->smem_read_offset_ ^=  1 * BYTES_PER_LDS * 2;
        }
    }

    // Reset the read offset.
    inline __device__ void reset_read_offset() {
        // The number of MMAs in the K dimension.
        enum { MMAS_K = Mma_tile::MMAS_K };
        // The number of MMAs in the K dimension when we include padding.
        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
        // Assemble the mask.
        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };

        // Reset the read offset.
        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE
>
struct Smem_tile_a<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {
    // The base class.
    using Base = Smem_tile_row_a<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor.
    inline __device__ Smem_tile_a(void *smem, int tidx) : Base(smem, tidx) {
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The layout of the tile.
    typename Layout,
    // The size of the STS.
    int BYTES_PER_STS = 16,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE = 1,
    // Use or not predicates
    bool USE_PREDICATES = true
>
struct Smem_tile_b {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
struct Rows_per_xor_pattern_b {
    // The size in bits.
    enum { N_IN_BITS = N * fmha::BITS_PER_ELEMENT_B };
    // The number of rows.
    enum { VALUE = N_IN_BITS <= 256 ? 2 : (N_IN_BITS <= 512 ? 4 : 8) };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
struct Rows_per_xor_pattern_col_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_col_b<Cta_tile::K>::VALUE
>
struct Smem_tile_col_b : public Smem_tile_without_skews<Cta_tile,
                                                        Cta_tile::N,
                                                        Cta_tile::K,
                                                        fmha::BITS_PER_ELEMENT_B,
                                                        BYTES_PER_STS,
                                                        BUFFERS_PER_TILE,
                                                        0,
                                                        ROWS_PER_XOR_PATTERN_,
                                                        1> {
    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile,
                                         Cta_tile::N,
                                         Cta_tile::K,
                                         fmha::BITS_PER_ELEMENT_B,
                                         BYTES_PER_STS,
                                         BUFFERS_PER_TILE,
                                         0,
                                         ROWS_PER_XOR_PATTERN_,
                                         1>;
    // The fragment.
    using Fragment = Fragment_b<Col>;
    // When we use padding to reach a power of two, special care has to be taken.
    using Cta_tile_with_padding = Cta_tile_with_k_with_padding<Cta_tile>;
    // The number of MMAs.
    using Mma_tile_with_padding = fmha::Hmma_tile<Cta_tile_with_padding>;
    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = 16 };
    // The number of STS per thread
    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
    // The number of STS per thread must be at least 1.
    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };

    // Ctor.
    inline __device__ Smem_tile_col_b(void *smem, int tidx) : Base(smem, tidx) {

        // For documentation on the layout, see doc/mma_smem_layout.xlsx.

        // The number of warps.
        const int WARPS_M = Cta_tile::WARPS_M;
        const int WARPS_N = Cta_tile::WARPS_N;
        const int WARPS_K = Cta_tile::WARPS_K;
        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);
        static_assert(WARPS_M == 1);
        static_assert(WARPS_N == 4 || WARPS_N == 8);
        static_assert(WARPS_K == 1);

        // The masks to select the warps.
        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;

        // The divisor for the warps.
        const int WARP_DIV_N = WARPS_M * 1 * Cta_tile::THREADS_PER_WARP;

        // The row and column read by the thread.
        int smem_read_row = (tidx & WARP_MASK_N) / WARP_DIV_N * Mma_tile::N_PER_MMA +
                            (tidx & 0x07) + (tidx & 0x10) / 2;
        int smem_read_col = (tidx & 0x07);
        smem_read_col ^= (tidx & 0x08) / 8;

        // The shared memory offset.
        this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;
    }

    // Rewind smem_read_offset for last LDS phase in main loop.
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        // Undo the pointer increment for the next ni.
        // Should match the load function below for ki = 0.
        if( Mma_tile_with_padding::MMAS_K >= 2 ) {
            this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
        }
    }

    // Load from shared memory.
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Jump by as many matrix rows as needed (a row in smem may pack multiple matrix rows).
            int offset = ni * Mma_tile::N_PER_MMA_PER_CTA * Base::BYTES_PER_ROW_BEFORE_PACKING;

            // Load using LDSM.M88.4.
            uint4 tmp;
            ldsm(tmp, this->smem_ + this->smem_read_offset_ + this->smem_read_buffer_ + offset);

            // Store the value into the fragment.
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;
        }

        // Move the offset to the next position. See doc/mma_smem_layout.xlsx.
        static_assert(Mma_tile_with_padding::MMAS_K < 64, "Not implemented");
        if(        Mma_tile_with_padding::MMAS_K >= 32 && ki % 16 == 15 ) {
            this->smem_read_offset_ ^= 31 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >= 16 && ki %  8 ==  7 ) {
            this->smem_read_offset_ ^= 15 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  8 && ki %  4 ==  3 ) {
            this->smem_read_offset_ ^=  7 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  4 && ki %  2 ==  1 ) {
            this->smem_read_offset_ ^=  3 * BYTES_PER_LDS * 2;
        } else if( Mma_tile_with_padding::MMAS_K >=  2 ) {
            this->smem_read_offset_ ^=  1 * BYTES_PER_LDS * 2;
        }
    }

    // Reset the read offset.
    inline __device__ void reset_read_offset() {
        // The number of MMAs in the K dimension.
        enum { MMAS_K = Mma_tile::MMAS_K };
        // The number of MMAs in the K dimension when we include padding.
        enum { MMAS_K_WITH_PADDING = Mma_tile_with_padding::MMAS_K };
        // Assemble the mask.
        enum { MASK = Compute_reset_mask<MMAS_K, MMAS_K_WITH_PADDING>::VALUE };

        // Reset the read offset.
        this->smem_read_offset_ ^= MASK * BYTES_PER_LDS * 2;
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Col, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {
    // The base class.
    using Base = Smem_tile_col_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor.
    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
struct Rows_per_xor_pattern_row_b : public Rows_per_xor_pattern_b<N> {
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE,
    // How many rows to use for the XOR pattern to avoid bank conflicts?
    int ROWS_PER_XOR_PATTERN_ = Rows_per_xor_pattern_row_b<Cta_tile::N>::VALUE,
    // How many cols to use for the XOR pattern to avoid bank conflicts?
    int COLS_PER_XOR_PATTERN_ = 1
>
struct Smem_tile_row_b : public Smem_tile_without_skews<Cta_tile,
                                                        Cta_tile::K,
                                                        Cta_tile::N,
                                                        fmha::BITS_PER_ELEMENT_B,
                                                        BYTES_PER_STS,
                                                        BUFFERS_PER_TILE,
                                                        0,
                                                        ROWS_PER_XOR_PATTERN_,
                                                        COLS_PER_XOR_PATTERN_> {
    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile,
                                         Cta_tile::K,
                                         Cta_tile::N,
                                         fmha::BITS_PER_ELEMENT_B,
                                         BYTES_PER_STS,
                                         BUFFERS_PER_TILE,
                                         0,
                                         ROWS_PER_XOR_PATTERN_,
                                         COLS_PER_XOR_PATTERN_>;
    // The fragment.
    using Fragment = Fragment_b<Row>;
    // Can we use LDSM? No if the data type is 32-bit large.
    enum { USE_LDSMT = fmha::BITS_PER_ELEMENT_B == 16 };
    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = USE_LDSMT ? 16 : 4 };
    // The number of elements per LDS.
    enum { ELEMENTS_PER_LDS = BYTES_PER_LDS * 8 / fmha::BITS_PER_ELEMENT_B };
    // The number of STS per thread
    enum { STS_PER_THREAD_ = Base::ROWS * Base::THREADS_PER_ROW / Cta_tile::THREADS_PER_CTA };
    // The number of STS per thread must be at least 1.
    enum { STS_PER_THREAD = Max<1, STS_PER_THREAD_>::VALUE };

    // Ctor.
    inline __device__ Smem_tile_row_b(void *smem, int tidx) : Base(smem, tidx) {

        // The number of warps.
        const int WARPS_M = Cta_tile::WARPS_M;
        const int WARPS_N = Cta_tile::WARPS_N;
        const int WARPS_K = Cta_tile::WARPS_K;
        static_assert(WARPS_K == 1);
        static_assert(WARPS_M == 4 || WARPS_M == 8);
        static_assert(WARPS_N == 1);

        // The masks to select the warps.
        const int WARP_MASK_N = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::N;
        const int WARP_MASK_K = Warp_masks<WARPS_M, WARPS_N, WARPS_K>::K;

        // The divisor for the warps.
        const int WARP_DIV_N = WARPS_M * 1       * Cta_tile::THREADS_PER_WARP;
        const int WARP_DIV_K = WARPS_M * WARPS_N * Cta_tile::THREADS_PER_WARP;

        // The row/col read by the thread.
        int smem_read_row, smem_read_col;

        static_assert(USE_LDSMT);
        static_assert(Base::ROWS_PER_XOR_PATTERN == 8);

        smem_read_row = (tidx & WARP_MASK_K) / WARP_DIV_K * Mma_tile::MMAS_K * 16 +
                        (tidx & 0x07) + (tidx & 0x08);
        smem_read_col = (tidx & 0x07);
        smem_read_col ^= (tidx & WARP_MASK_N) / WARP_DIV_N * 2 + (tidx & 0x10) / 16;

        // The shared memory offset.
        this->smem_read_offset_ = smem_read_row * Base::BYTES_PER_ROW + smem_read_col * BYTES_PER_LDS;

        // Fill zeroes for group conv
    }

    // Rewind smem_read_offset for last LDS phase in main loop.
    inline __device__ void reverse_smem_read_offset(int ki = 0) {
        // The size of each element in bits.
        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
        // The size in bytes of the data needed to compute an MMA per CTA.
        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;

        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Undo the pointer increment for the next ni.
            // Should match the load function below for ki = 0.
            if( BYTES_PER_MMA_PER_CTA >= 128 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
                this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
            }
        }

        // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1 ) {
            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
        }
    }

    // Load from shared memory.
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        // The size of each element in bits.
        const int BITS_PER_ELT = fmha::BITS_PER_ELEMENT_B;
        // The size in bytes of the data needed to compute an MMA per CTA.
        const int BYTES_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA * BITS_PER_ELT / 8;

        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Prepare the offset.
            int offset = ki * Base::ROWS_PER_XOR_PATTERN * 2 * Base::BYTES_PER_ROW;
            if( BYTES_PER_MMA_PER_CTA == 32 ) {
                offset += this->smem_read_offset_;
            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                offset += this->smem_read_offset_ + (ni / 2) * BYTES_PER_MMA_PER_CTA * 2;
            } else {
                offset += this->smem_read_offset_ + (ni) * BYTES_PER_MMA_PER_CTA;
            }

            // Load the data using LDSM.MT88.2.
            uint32_t ptr = this->smem_ + this->smem_read_buffer_ + offset;
            uint4 tmp;
            if( USE_LDSMT ) {
                ldsmt(tmp, ptr);
            } else {
                lds(tmp.x, (ptr     ) + 0 * Base::BYTES_PER_ROW);
                lds(tmp.y, (ptr     ) + 4 * Base::BYTES_PER_ROW);
                lds(tmp.z, (ptr ^ 32) + 0 * Base::BYTES_PER_ROW);
                lds(tmp.w, (ptr ^ 32) + 4 * Base::BYTES_PER_ROW);
            }

            // Store those values in the fragment.
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;

            // Move the pointer for the next ni. I expect the compiler to not recompute those.
            if( BYTES_PER_MMA_PER_CTA >= 128 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 ) {
                this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
            } else if( BYTES_PER_MMA_PER_CTA == 64 ) {
                // Nothing to do!
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 4 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
            } else if( BYTES_PER_MMA_PER_CTA == 32 && Mma_tile::MMAS_N == 2 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * 2;
            }
        }

        // Reset smem_read_offset for odd MMAS_N > 1 (npo2 kernels)
        if( BYTES_PER_MMA_PER_CTA == 64 && Mma_tile::MMAS_N > 1 && Mma_tile::MMAS_N % 2 == 1 ) {
            this->smem_read_offset_ ^= BYTES_PER_MMA_PER_CTA;
        }
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<
    // The dimensions of the tile computed by the CTA.
    typename Cta_tile,
    // The size of the STS.
    int BYTES_PER_STS,
    // The number of buffers per tile.
    int BUFFERS_PER_TILE
>
struct Smem_tile_b<Cta_tile, Row, BYTES_PER_STS, BUFFERS_PER_TILE>
    : public Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE> {
    // The base class.
    using Base = Smem_tile_row_b<Cta_tile, BYTES_PER_STS, BUFFERS_PER_TILE>;

    // Ctor.
    inline __device__ Smem_tile_b(void *smem, int tidx) : Base(smem, tidx) {
    }
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile>
struct Smem_tile_v
    : public fmha::Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1> {

    // The base class.
    using Base = Smem_tile_without_skews<Cta_tile, Cta_tile::K, Cta_tile::N, 16, 16, 1, 0, 8, 1>;
    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The fragment.
    using Fragment = Fragment_b<fmha::Col>;

    // The size of a single LDS in bytes.
    enum { BYTES_PER_LDS = 16 };

    // Ctor.
    inline __device__ Smem_tile_v(void *smem, int tidx) : Base(smem, tidx) {

        // The row/col read by the thread.
        int read_row, read_col;

        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 &&
                      (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));

        read_row = (tidx & 0xe0) / 2 + (tidx & 0x0f);
        read_col = (tidx & 0x07);
        read_col ^= (tidx & 0x10) / 16;

        // The shared memory offset.
        this->smem_read_offset_ = read_row * Base::BYTES_PER_ROW + read_col * BYTES_PER_LDS;
    }

    // Load from shared memory.
    inline __device__ void load(Fragment (&b)[Mma_tile::MMAS_N], int ki) {
        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {
            // Jump by 16 * #warps row.
            int row = ki * 16 * Cta_tile::WARPS_K;

            // Load the data using LDSM.MT88.2.
            uint4 tmp;
            fmha::ldsmt(tmp, this->smem_ + this->smem_read_offset_ + row * Base::BYTES_PER_ROW);
            b[ni].reg(0) = tmp.x;
            b[ni].reg(1) = tmp.y;
            b[ni].reg(2) = tmp.z;
            b[ni].reg(3) = tmp.w;

            // Move the pointer for the next ni. I expect the compiler to not recompute those.
            if( Mma_tile::MMAS_N == 4 ) {
                this->smem_read_offset_ ^= BYTES_PER_LDS * (ni % 2 == 0 ? 2 : 6);
            } else {
                assert(false);  // Not implemented!
            }
        }
    }
};
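// Illustrative sketch (not part of the original header): what the bit tricks in the Smem_tile_v
// ctor decompose to. With 32-thread warps, (tidx & 0xe0) / 2 is warp_id * 16 and (tidx & 0x0f)
// is lane % 16, so each warp owns a 16-row slice of V, while the low bit of lane / 16 flips the
// LDSM column to stay clear of bank conflicts.
constexpr int example_v_read_row(int tidx) { return (tidx & 0xe0) / 2 + (tidx & 0x0f); }
constexpr int example_v_read_col(int tidx) { return (tidx & 0x07) ^ ((tidx & 0x10) / 16); }
static_assert(example_v_read_row(0)   == 0,       "");
static_assert(example_v_read_row(32)  == 16,      "");  // 2nd warp starts at row 16.
static_assert(example_v_read_row(100) == 48 + 4,  "");  // tidx 100 = warp 3, lane 4.
static_assert(example_v_read_col(23)  == (7 ^ 1), "");  // lane 23: column 7 XOR-ed by 1.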
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile>
struct Smem_tile_o {

    // The MMA tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    // The accumulators.
    using Accumulator = fmha::Fragment_accumulator;
    // The accumulators.
    using Data_type = typename Accumulator::Data_type;

    // The size of each element.
    enum { BYTES_PER_ELEMENT = sizeof(Data_type) };
    // The size of each STS.
    enum { BYTES_PER_STS = 8 };
    // The size of each row in shared memory.
    enum { BYTES_PER_ROW = Cta_tile::N * Cta_tile::WARPS_K * BYTES_PER_ELEMENT };
    // The size of each LDS.
    enum { BYTES_PER_LDS = 16 };
    enum { THREADS_PER_ROW = 16 };

    // The number of rows.
    enum { ROWS = Cta_tile::M };
    // The number of "rows" to process per loop iteration (in the "epilogue").
    enum { ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA };
    // The number of outer loops.
    enum { LOOPS = ROWS / ROWS_PER_LOOP };
    // Make sure it matches our expectations.
    static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");

    // The number of rows loaded per LDS.
    enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    // Do we have to guard against partial writes/reads.
    enum { HAS_INCOMPLETE_LDS = ROWS_PER_LOOP % ROWS_PER_LDS != 0 };
    // The total number of LDS per loop.
    enum { LDS_PER_LOOP = fmha::Div_up<ROWS_PER_LOOP, ROWS_PER_LDS>::VALUE };
    // The amount of shared memory.
    enum { BYTES_PER_TILE = ROWS_PER_LOOP * BYTES_PER_ROW };

    // The write pointer.
    uint32_t smem_write_, smem_read_;
    // Is the thread active for the last LDS of the series?
    int is_active_for_last_lds_;

    static_assert(BYTES_PER_ROW == 64 * 4 * Cta_tile::WARPS_K);
    static_assert(LOOPS == 1 || LOOPS == (int)Mma_tile::MMAS_M, "");

    // Ctor.
    inline __device__ Smem_tile_o(void *smem, int tidx) {

        // Get a 32-bit value for the shared memory address.
        uint32_t smem_ = __nvvm_get_smem_pointer(smem);

        static_assert(Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 1 &&
                      (Cta_tile::WARPS_K == 4 || Cta_tile::WARPS_K == 8));

        int write_row = (tidx & 0x1c) / 4;
        int write_col = (tidx);

        // Assemble the write pointer.
        smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;

        // The element read by each thread.
        int read_row = tidx / THREADS_PER_ROW;
        int read_col = tidx % THREADS_PER_ROW;

        // Take the XOR pattern into account for the column.
        read_col ^= 2 * (read_row & 0x7);

        // Assemble the read pointer.
        this->smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;

        // Is that thread active on the last LDS?
        if( HAS_INCOMPLETE_LDS ) {
            this->is_active_for_last_lds_ = read_row + (LDS_PER_LOOP - 1) * ROWS_PER_LDS < Cta_tile::M;
        }
    }

    // Load the output fragments.
    inline __device__ void load(uint4 (&out)[LDS_PER_LOOP]) const {
        #pragma unroll
        for( int ii = 0; ii < LDS_PER_LOOP; ++ii ) {

            // Load the elements before the reduction (split-K).
            uint4 tmp[Cta_tile::WARPS_K];
            #pragma unroll
            for( int jj = 0; jj < Cta_tile::WARPS_K; ++jj ) {
                int imm = ii * ROWS_PER_LDS * BYTES_PER_ROW + jj * Cta_tile::N * BYTES_PER_ELEMENT;
                if( !HAS_INCOMPLETE_LDS || (ii < LDS_PER_LOOP - 1 || this->is_active_for_last_lds_) ) {
                    fmha::lds(tmp[jj], this->smem_read_ + imm);
                }
            }

            // Perform the reduction.
            out[ii] = tmp[0];
            #pragma unroll
            for( int jj = 1; jj < Cta_tile::WARPS_K; ++jj ) {
                out[ii] = fmha::fadd4(out[ii], tmp[jj]);
            }
        }
    }

    // Store the accumulators.
    template<int M, int N>
    inline __device__ void store(const Accumulator (&acc)[M][N], int mi) {
        enum { M_PER_MMA = Mma_tile::M_PER_MMA_PER_CTA };
        #pragma unroll
        for( int ni = 0; ni < Mma_tile::MMAS_N; ++ni ) {

            // The number of MMAs that are stored per loop iteration.
            enum { MMAS_M_PER_LOOP = Mma_tile::MMAS_M / LOOPS };

            // Store 1st column of the different MMAs.
            #pragma unroll
            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
                // Precompute the immediates to jump between rows.
                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;

                uint2 tmp0, tmp1;
                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(0);
                tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(1);

                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(2);
                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(3);

                // Store.
                fmha::sts(this->smem_write_ + row_0, tmp0);
                fmha::sts(this->smem_write_ + row_1, tmp1);
            }

            // Swizzle the write pointer using a XOR of 16B.
            this->smem_write_ ^= 32;

            // Store 2nd column of the different MMAs.
            #pragma unroll
            for( int mj = 0; mj < MMAS_M_PER_LOOP; ++mj ) {
                // Precompute the immediates to jump between rows.
                int row_0 = (mj * M_PER_MMA + 0) * BYTES_PER_ROW;
                int row_1 = (mj * M_PER_MMA + 8) * BYTES_PER_ROW;

                uint2 tmp0, tmp1;
                tmp0.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(4);
                tmp0.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(5);

                tmp1.x = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(6);
                tmp1.y = acc[mi * MMAS_M_PER_LOOP + mj][ni].reg(7);

                // Store.
                fmha::sts(this->smem_write_ + row_0, tmp0);
                fmha::sts(this->smem_write_ + row_1, tmp1);
            }

            // Cancel the previous XOR of 1 + swizzle the write pointer using a XOR of 32B or 64B.
            this->smem_write_ ^= (ni & 1) ? 7 * 32 : 3 * 32;
        }
    }
};
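// Illustrative sketch (not part of the original header): the split-K reduction performed by
// Smem_tile_o::load() above, written with plain float4 adds instead of the packed fadd4()
// helper. For the N = 64 tiles used here, each of the WARPS_K warps has written its partial
// O-slice into its own Cta_tile::N * 4 = 256B segment of the row; the epilogue sums the slices.
template<int WARPS_K>
inline __device__ float4 example_split_k_reduce(const float4 (&partial)[WARPS_K]) {
    float4 out = partial[0];
    #pragma unroll
    for( int jj = 1; jj < WARPS_K; ++jj ) {
        out.x += partial[jj].x;
        out.y += partial[jj].y;
        out.z += partial[jj].z;
        out.w += partial[jj].w;
    }
    return out;
}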
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile>
struct Smem_tile_mma {

    using Mma_tile = fmha::Hmma_tile<Cta_tile>;
    using Fragment = fmha::Fragment_a<fmha::Col>;

    enum { COLS = Cta_tile::N };
    enum { BYTES_PER_ELT = 2 };
    enum { BYTES_PER_STS = 4 };
    enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT };
    // TODO
    enum { BYTES_PER_TILE = Cta_tile::M * BYTES_PER_ROW };

    enum { WARPS_M = Cta_tile::WARPS_M };
    enum { WARPS_N = Cta_tile::WARPS_N };
    enum { WARPS_K = Cta_tile::WARPS_K };
    static_assert(WARPS_K == 1);

    inline __device__ Smem_tile_mma(char *smem, int tidx) {
        smem_ = __nvvm_get_smem_pointer(smem);

        int write_col, write_row;
        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ||
                      (WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
        if( WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8) ) {
            write_row = (tidx & 0x1c) / 4;
            write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
        } else {
            write_row = (tidx & 0xe0) / 2 + (tidx & 0x1c) / 4;
            write_col = (tidx & 0x03);
        }
        write_col ^= (write_row & 0x07) * 4;

        write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
    }

    template<int M, int N>
    inline __device__ void store(const uint4 (&regs)[M][N]) {
        static_assert(COLS == Cta_tile::N);
        for( int mi = 0; mi < M; mi++ ) {
            for( int ni = 0; ni < N; ni++ ) {
                size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW +
                                ni * WARPS_N * 16 * BYTES_PER_ELT;
                fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
                fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
                offset ^= 4 * BYTES_PER_STS;
                fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
                fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
            }
        }
    }

    uint32_t smem_;
    uint32_t write_offset_;
    uint32_t warp_m;
    uint32_t warp_n;
    uint32_t lane;
};
template<typename Cta_tile, typename Base = Smem_tile_mma<Cta_tile>>
struct Smem_tile_mma_transposed : public Base {
    enum { BYTES_PER_LDS = 16 };
    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
    enum { WARPS_M = Base::WARPS_M };
    enum { WARPS_N = Base::WARPS_N };
    static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
    using Fragment = typename Base::Fragment;

    inline __device__ Smem_tile_mma_transposed(char *smem, int tidx) : Base(smem, tidx) {
        static_assert(WARPS_M == 1 && (WARPS_N == 4 || WARPS_N == 8));
        int read_row, read_col;
        read_row = (tidx & 0x0f);
        read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;

        read_col ^= (read_row & 0x07);
        read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
    }

    template<int M, int N>
    inline __device__ void load(Fragment (&frag)[M][N]) {
        static_assert(Base::COLS == Cta_tile::N);
        for( int mi = 0; mi < M; mi++ ) {
            for( int ni = 0; ni < N; ni++ ) {
                size_t offset = read_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW +
                                ni * WARPS_N * 16 * BYTES_PER_ELT;
                uint4 dst;
                fmha::ldsmt(dst, this->smem_ + offset);
                frag[mi][ni].reg(0) = dst.x;
                frag[mi][ni].reg(1) = dst.z;  // Fragment A regs col major!
                frag[mi][ni].reg(2) = dst.y;
                frag[mi][ni].reg(3) = dst.w;
            }
        }
    }

    uint32_t read_offset_;
};
template<typename Cta_tile, typename Base = Smem_tile_mma<Cta_tile>>
struct Smem_tile_mma_epilogue : public Base {
    enum { BYTES_PER_LDS = 16 };
    enum { BYTES_PER_ROW = Base::BYTES_PER_ROW };
    enum { BYTES_PER_ELT = Base::BYTES_PER_ELT };
    enum { THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDS };
    static_assert(THREADS_PER_ROW * BYTES_PER_LDS == BYTES_PER_ROW);
    enum { ROWS_PER_LDS = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW };
    enum { NUM_LDS = Cta_tile::M / ROWS_PER_LDS };
    static_assert(NUM_LDS * ROWS_PER_LDS == Cta_tile::M);
    enum { WARPS_M = Base::WARPS_M };
    enum { WARPS_N = Base::WARPS_N };
    static_assert((WARPS_M == 4 || WARPS_N == 8) || WARPS_N == 1);
    using Fragment = typename Base::Fragment;

    inline __device__ Smem_tile_mma_epilogue(char *smem, int tidx) : Base(smem, tidx) {
        const int read_row = tidx / THREADS_PER_ROW;
        int read_col = tidx % THREADS_PER_ROW;
        read_col ^= (read_row & 0x07);
        read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
    }

    inline __device__ void load(uint4 (&data)[NUM_LDS]) {
        for( int ii = 0; ii < NUM_LDS; ii++ ) {
            size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW;
            fmha::lds(data[ii], this->smem_ + offset);
        }
    }

    template<int M, int N>
    inline __device__ void store(const uint4 (&regs)[M][N]) {
        for( int mi = 0; mi < M; mi++ ) {
            for( int ni = 0; ni < N; ni++ ) {
                size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
                offset ^= 4 * Base::BYTES_PER_STS;
                fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
                fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
            }
        }
    }

    uint32_t read_offset_;
};

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha/softmax.h 0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Sum_ {
    enum { IS_SUM = 1 };
    static inline __device__ float apply(float x, float y) {
        return x + y;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Max_ {
    enum { IS_SUM = 0 };
    static inline __device__ float apply(float x, float y) {
        return x > y ? x : y;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float apply_exp_(float x, float max) {
    return __expf(x - max);
}
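// Illustrative sketch (not part of the original header): the per-row math that Softmax_base
// below spreads across registers, warp shuffles and shared memory, written for a single row
// held by one thread. Same building blocks: Max_ for the row maximum, apply_exp_ for the
// shifted exponential, Sum_ for the normalizer, and the same zero/NaN guard as scale().
template<int N>
inline __device__ void example_softmax_row(float (&row)[N]) {
    float max = row[0];
    #pragma unroll
    for( int i = 1; i < N; ++i ) {
        max = Max_::apply(max, row[i]);
    }
    float sum = 0.f;
    #pragma unroll
    for( int i = 0; i < N; ++i ) {
        row[i] = apply_exp_(row[i], max);  // exp(x - max) keeps the exponential in range.
        sum = Sum_::apply(sum, row[i]);
    }
    // A zero or NaN sum leaves the row unscaled, exactly like Softmax_base::scale().
    float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
    #pragma unroll
    for( int i = 0; i < N; ++i ) {
        row[i] *= inv_sum;
    }
}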
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {

    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    enum { MMAS_M = Mma_tile::MMAS_M };
    enum { MMAS_N = Mma_tile::MMAS_N };

    // The number of groups of warp such that we have at most 4 warps writing consecutive elements.
    enum { GROUPS = fmha::Div_up<Cta_tile::WARPS_N, 4>::VALUE };
    // The number of elements that we are going to store per row.
    enum { ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS };
    // The number of rows.
    enum { ROWS = Cta_tile::M * GROUPS };
    // The total number of elements.
    enum { ELEMENTS = ROWS * ELEMENTS_PER_ROW };

    // Ctor.
    template<typename Params>
    inline __device__ Softmax_base(const Params &params, void *smem, int bidb, int tidx)
        :  // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
          smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {

        // Move to the 1st mask loaded by the thread+ tidx;
        // packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);

        // Extract the position in the warp.
        int warp = tidx / Cta_tile::THREADS_PER_WARP;
        int lane = tidx % Cta_tile::THREADS_PER_WARP;

        // Decompose the warp index into M and N.
        int warp_m = warp % Cta_tile::WARPS_M;
        int warp_n = warp / Cta_tile::WARPS_M;

        // Decompose the warp-n index into group/position-inside-the-group.
        int warp_g = warp_n / ELEMENTS_PER_ROW;
        int warp_i = warp_n % ELEMENTS_PER_ROW;

        // The location written by the threads.
        int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
        int write_col = warp_i;

        // Assemble the write pointer.
        smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];

        // Assemble the read pointer.
        smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
    }

    template<typename Mask>
    inline __device__ void apply_mask(const Mask &mask) {
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ii = 0; ii < 2; ++ii ) {
                #pragma unroll
                for( int ni = 0; ni < MMAS_N; ++ni ) {
                    #pragma unroll
                    for( int jj = 0; jj < 4; ++jj ) {
                        if( !mask.is_valid(mi, ni, ii, jj) ) {
                            elt_[2 * mi + ii][4 * ni + jj] = -INFINITY;
                        }
                    }
                }
            }
        }
    }

    // Apply the exp to all the elements.
    inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
        #pragma unroll
        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
                elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
            }
        }
    }

    // Do a CTA-wide reduction.
    template<typename Functor>
    inline __device__ void reduce_1x4(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
        if( Functor::IS_SUM ) {
            // Apply the summation inside the thread.
            float tmp[MMAS_M * 2][2];
            #pragma unroll
            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
                tmp[mi][0] = 0.f;
                tmp[mi][1] = 0.f;
                #pragma unroll
                for( int ni = 0; ni < MMAS_N; ++ni ) {
                    tmp[mi][0] += elt_[mi][4 * ni + 0];
                    tmp[mi][0] += elt_[mi][4 * ni + 1];
                    tmp[mi][1] += elt_[mi][4 * ni + 2];
                    tmp[mi][1] += elt_[mi][4 * ni + 3];
                }
                dst[mi] = tmp[mi][0] + tmp[mi][1];
            }
        } else
#endif  // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
        {
            // Apply the functor for each row inside a thread.
            #pragma unroll
            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
                dst[mi] = elt_[mi][0];
                #pragma unroll
                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
                }
            }
        }

        // Apply the functor for each row inside each group of 4 threads.
        #pragma unroll
        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
            __syncwarp();
            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
            __syncwarp();
        }

        // Store the different values.
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            if( tidx_ % 4 == 0 ) {
                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
            }
        }

        // Make sure the values are in shared memory.
        __syncthreads();

        // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the
        // float4.
        float4 tmp[1];
        if( tidx_ < Cta_tile::M ) {
            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
        }

        // Compute the reduction of those 8 values in a binary-tree fashion.
        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);

        // Make sure we can write to shared memory.
        __syncthreads();

        // Store the value back to shared memory.
        if( tidx_ < Cta_tile::M ) {
            smem_[tidx_] = tmp[0].x;
        }

        // Make sure the data is in shared memory.
        __syncthreads();

        // Finally read the values.
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
        }
    }

    // Do a CTA-wide reduction.
    template<typename Functor>
    inline __device__ void reduce_1x8(float (&dst)[MMAS_M * 2]) {
#if defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
        if( Functor::IS_SUM ) {
            // Apply the summation inside the thread.
            float tmp[MMAS_M * 2][2];
            #pragma unroll
            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
                tmp[mi][0] = 0.f;
                tmp[mi][1] = 0.f;
                #pragma unroll
                for( int ni = 0; ni < MMAS_N; ++ni ) {
                    tmp[mi][0] += elt_[mi][4 * ni + 0];
                    tmp[mi][0] += elt_[mi][4 * ni + 1];
                    tmp[mi][1] += elt_[mi][4 * ni + 2];
                    tmp[mi][1] += elt_[mi][4 * ni + 3];
                }
                dst[mi] = tmp[mi][0] + tmp[mi][1];
            }
        } else
#endif  // defined(USE_SAME_SUM_ORDER_IN_SOFTMAX_AS_REF_CODE)
        {
            // Apply the functor for each row inside a thread.
            #pragma unroll
            for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
                dst[mi] = elt_[mi][0];
                #pragma unroll
                for( int ni = 1; ni < MMAS_N * 4; ++ni ) {
                    dst[mi] = Functor::apply(dst[mi], elt_[mi][ni]);
                }
            }
        }

        // Apply the functor for each row inside each group of 4 threads.
        #pragma unroll
        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 1));
            __syncwarp();
            dst[mi] = Functor::apply(dst[mi], __shfl_xor_sync(uint32_t(-1), dst[mi], 2));
            __syncwarp();
        }

        // Store the different values.
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            if( tidx_ % 4 == 0 ) {
                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 0) * ELEMENTS_PER_ROW] = dst[2 * mi + 0];
                smem_write_[(mi * Mma_tile::M_PER_MMA_PER_CTA + 8) * ELEMENTS_PER_ROW] = dst[2 * mi + 1];
            }
        }

        // Make sure the values are in shared memory.
        __syncthreads();

        // Load 8 values (one for each warp). The /8 corresponds to /(4*2) where 4 is from the
        // float4.
        float4 tmp[2];
        if( tidx_ < Cta_tile::M ) {
            tmp[0] = reinterpret_cast<const float4 *>(&smem_[0 * ELEMENTS / 2])[tidx_];
            tmp[1] = reinterpret_cast<const float4 *>(&smem_[1 * ELEMENTS / 2])[tidx_];
        }

        // Compute the reduction of those 8 values in a binary-tree fashion.
        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].y);
        tmp[0].z = Functor::apply(tmp[0].z, tmp[0].w);
        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].y);
        tmp[1].z = Functor::apply(tmp[1].z, tmp[1].w);
        tmp[0].x = Functor::apply(tmp[0].x, tmp[0].z);
        tmp[1].x = Functor::apply(tmp[1].x, tmp[1].z);
        tmp[0].x = Functor::apply(tmp[0].x, tmp[1].x);

        // Make sure we can write to shared memory.
        __syncthreads();

        // Store the value back to shared memory.
        if( tidx_ < Cta_tile::M ) {
            smem_[tidx_] = tmp[0].x;
        }

        // Make sure the data is in shared memory.
        __syncthreads();

        // Finally read the values.
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            dst[2 * mi + 0] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 0];
            dst[2 * mi + 1] = smem_read_[mi * Mma_tile::M_PER_MMA_PER_CTA + 8];
        }
    }

    // Do a CTA-wide reduction.
    template<typename Functor>
    inline __device__ void reduce(float (&dst)[MMAS_M * 2]) {
        static_assert(Cta_tile::WARPS_M == 1 && (Cta_tile::WARPS_N == 4 || Cta_tile::WARPS_N == 8));
        if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 4 ) {
            reduce_1x4<Functor>(dst);
        } else if( Cta_tile::WARPS_M == 1 && Cta_tile::WARPS_N == 8 ) {
            reduce_1x8<Functor>(dst);
        } else {
            assert(false);
        }

        // Make sure we are done reading from shared memory.
        __syncthreads();
    }

    // Scale all the elements.
    inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
        // Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
        float inv_sum[MMAS_M * 2];
        #pragma unroll
        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
            inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
        }

        // Update the values.
        #pragma unroll
        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
                elt_[mi][ni] *= inv_sum[mi];
            }
        }
    }

    // The pointer to the mask.
    const char *packed_mask_ptr_;
    // Shared memory for the CTA-wide reduction.
    float *smem_, *smem_write_, *smem_read_;
    // The current thread index.
    int tidx_;
    // The elements.
    float elt_[MMAS_M * 2][MMAS_N * 4];
};
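// Illustrative sketch (not part of the original header): the quad-level step used by
// reduce_1x4 / reduce_1x8 above. In an HMMA fragment, the 4 consecutive lanes of a quad own one
// row, so XOR-shuffles with lane masks 1 and 2 are enough to give all 4 lanes the row-wide
// result before the CTA-level pass through shared memory.
template<typename Functor>
inline __device__ float example_quad_reduce(float x) {
    x = Functor::apply(x, __shfl_xor_sync(uint32_t(-1), x, 1));
    x = Functor::apply(x, __shfl_xor_sync(uint32_t(-1), x, 2));
    return x;
}
// E.g. example_quad_reduce<Sum_>(elt) returns the sum of elt over the 4 lanes of the quad.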
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {

    // The base class.
    using Base = Softmax_base<Cta_tile, Kernel_traits>;
    // The fragment.
    using Fragment_a = fmha::Fragment_a<fmha::Row>;

    static_assert(Fragment_a::NUM_REGS == 4);

    // The MMAs.
    enum { MMAS_M = Base::MMAS_M };
    enum { MMAS_N = Base::MMAS_N };

    // The accumulators.
    using Accumulator = fmha::Fragment_accumulator;
    using Accumulator_out = Fragment<uint16_t, 8>;
    static_assert(Accumulator_out::NUM_REGS == 4);
    static_assert(std::is_same<Accumulator::Data_type, float>::value);

    // Ctor.
    template<typename Params>
    inline __device__ Softmax(const Params &params, void *smem, int bidb, int tidx)
        : Base(params, smem, bidb, tidx), params_scale_bmm1_(params.scale_bmm1) {
    }

    // Store the tile after softmax.
    template<typename Gmem_tile>
    inline __device__ void store(Gmem_tile &gmem_tile) {
        Accumulator_out acc[MMAS_M][MMAS_N];
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < MMAS_N; ++ni ) {

                // The elements.
                float tmp_00 = this->elt_[2 * mi + 0][4 * ni + 0];
                float tmp_01 = this->elt_[2 * mi + 0][4 * ni + 1];
                float tmp_02 = this->elt_[2 * mi + 0][4 * ni + 2];
                float tmp_03 = this->elt_[2 * mi + 0][4 * ni + 3];
                float tmp_10 = this->elt_[2 * mi + 1][4 * ni + 0];
                float tmp_11 = this->elt_[2 * mi + 1][4 * ni + 1];
                float tmp_12 = this->elt_[2 * mi + 1][4 * ni + 2];
                float tmp_13 = this->elt_[2 * mi + 1][4 * ni + 3];

                // Transform to accumulators.
                acc[mi][ni].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
                acc[mi][ni].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
                acc[mi][ni].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
                acc[mi][ni].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
            }
        }

        // Delegate to the gmem tile to store.
        gmem_tile.store(acc);
    }

    // Pack the data to a fragment for the next GEMM.
    template<int K, int M>
    inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
        #pragma unroll
        for( int mi = 0; mi < M; ++mi ) {
            #pragma unroll
            for( int ki = 0; ki < K; ++ki ) {

                // 1st row - 4 elements per row.
                float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
                float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
                float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
                float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];

                // 2nd row - 4 elements per row.
                float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
                float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
                float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
                float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];

                // Pack to 4 registers.
                dst[ki][mi].reg(0) = fmha::float2_to_half2(tmp_00, tmp_01);
                dst[ki][mi].reg(1) = fmha::float2_to_half2(tmp_10, tmp_11);
                dst[ki][mi].reg(2) = fmha::float2_to_half2(tmp_02, tmp_03);
                dst[ki][mi].reg(3) = fmha::float2_to_half2(tmp_12, tmp_13);
            }
        }
    }

    // Scale FP32 fragments
    inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
        const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < MMAS_N; ++ni ) {
                // 1st row - 4 elements per row.
                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
                // 2nd row - 4 elements per row.
                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
            }
        }
    }

    const uint32_t params_scale_bmm1_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
apex/contrib/csrc/fmha/src/fmha/utils.h 0 → 100644
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Row {};
struct Col {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int M, bool = (M & (M - 1)) == 0 >
struct Next_power_of_two {
};

template< int M >
struct Next_power_of_two<  M, true > { enum { VALUE =   M }; };
template<>
struct Next_power_of_two<  3, false> { enum { VALUE =   4 }; };
template<>
struct Next_power_of_two<  5, false> { enum { VALUE =   8 }; };
template<>
struct Next_power_of_two<  6, false> { enum { VALUE =   8 }; };
template<>
struct Next_power_of_two<  7, false> { enum { VALUE =   8 }; };
template<>
struct Next_power_of_two<  9, false> { enum { VALUE =  16 }; };
template<>
struct Next_power_of_two< 10, false> { enum { VALUE =  16 }; };
template<>
struct Next_power_of_two< 11, false> { enum { VALUE =  16 }; };
template<>
struct Next_power_of_two< 12, false> { enum { VALUE =  16 }; };
template<>
struct Next_power_of_two< 13, false> { enum { VALUE =  16 }; };
template<>
struct Next_power_of_two< 14, false> { enum { VALUE =  16 }; };
template<>
struct Next_power_of_two< 15, false> { enum { VALUE =  16 }; };
template<>
struct Next_power_of_two< 24, false> { enum { VALUE =  32 }; };
template<>
struct Next_power_of_two< 48, false> { enum { VALUE =  64 }; };
template<>
struct Next_power_of_two< 80, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two< 96, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<112, false> { enum { VALUE = 128 }; };
template<>
struct Next_power_of_two<144, false> { enum { VALUE = 256 }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, bool = (N & (N - 1)) == 0 >
struct Prev_power_of_two {
};

template< int N >
struct Prev_power_of_two< N, true > { enum { VALUE = N }; };
template<>
struct Prev_power_of_two< 3, false> { enum { VALUE = 2 }; };
template<>
struct Prev_power_of_two< 5, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 6, false> { enum { VALUE = 4 }; };
template<>
struct Prev_power_of_two< 7, false> { enum { VALUE = 4 }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int M, int N >
struct Div_up {
    enum { VALUE = (M + N - 1) / N };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int A, int B >
struct Max {
    enum { VALUE = A >= B ? A : B };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int A, int B, int C >
struct Max_3 {
    enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int A, int B >
struct Min {
    enum { VALUE = A <= B ? A : B };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int SIZE_IN_BYTES >
struct Uint_from_size_in_bytes {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<1> {
    using Type = uint8_t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<2> {
    using Type = uint16_t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<4> {
    using Type = uint32_t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<8> {
    using Type = uint2;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<16> {
    using Type = uint4;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int WARPS_M, int WARPS_N, int WARPS_K >
struct Warp_masks {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Warp_masks<8, 1, 1> { enum { M = 0xe0, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<4, 2, 1> { enum { M = 0x60, N = 0x80, K = 0x00 }; };
template<>
struct Warp_masks<4, 1, 2> { enum { M = 0x60, N = 0x00, K = 0x80 }; };
template<>
struct Warp_masks<4, 1, 1> { enum { M = 0x60, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<2, 4, 1> { enum { M = 0x20, N = 0xc0, K = 0x00 }; };
template<>
struct Warp_masks<2, 2, 2> { enum { M = 0x20, N = 0x40, K = 0x80 }; };
template<>
struct Warp_masks<2, 2, 1> { enum { M = 0x20, N = 0x40, K = 0x00 }; };
template<>
struct Warp_masks<2, 1, 2> { enum { M = 0x20, N = 0x00, K = 0x40 }; };
template<>
struct Warp_masks<2, 1, 1> { enum { M = 0x20, N = 0x00, K = 0x00 }; };
template<>
struct Warp_masks<1, 8, 1> { enum { M = 0x00, N = 0xe0, K = 0x00 }; };
template<>
struct Warp_masks<1, 4, 2> { enum { M = 0x00, N = 0x60, K = 0x80 }; };
template<>
struct Warp_masks<1, 4, 1> { enum { M = 0x00, N = 0x60, K = 0x00 }; };
template<>
struct Warp_masks<1, 2, 2> { enum { M = 0x00, N = 0x20, K = 0x40 }; };
template<>
struct Warp_masks<1, 2, 1> { enum { M = 0x00, N = 0x20, K = 0x00 }; };
template<>
struct Warp_masks<1, 1, 4> { enum { M = 0x00, N = 0x00, K = 0x60 }; };
template<>
struct Warp_masks<1, 1, 2> { enum { M = 0x00, N = 0x00, K = 0x20 }; };
template<>
struct Warp_masks<1, 1, 1> { enum { M = 0x00, N = 0x00, K = 0x00 }; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename T >
inline __device__ __host__ T div_up(T m, T n) {
    return (m + n - 1) / n;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int clz(int x) {
    for( int i = 31; i >= 0; --i ) {
        if( (1 << i) & x ) {
            return 31 - i;
        }
    }
    return 32;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline int find_log_2(int x, bool round_up = false) {
    int a = 31 - clz(x);
    if( round_up ) {
        a += (x & (x - 1)) ? 1 : 0;
    }
    return a;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hadd2(uint32_t a, uint32_t b) {
    uint32_t c;
    asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hmin2(uint32_t a, uint32_t b) {
    uint32_t c;
    asm volatile("min.f16x2 %0, %1, %2;" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hmul2(uint32_t a, uint32_t b) {
    uint32_t c;
    asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint2 hmul4(uint2 a, uint2 b) {
    uint2 c;
    c.x = hmul2(a.x, b.x);
    c.y = hmul2(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint4 hmul8(uint4 a, uint4 b) {
    uint4 c;
    c.x = hmul2(a.x, b.x);
    c.y = hmul2(a.y, b.y);
    c.z = hmul2(a.z, b.z);
    c.w = hmul2(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint4 hmul8(uint32_t a, uint4 b) {
    uint4 c;
    c.x = hmul2(a, b.x);
    c.y = hmul2(a, b.y);
    c.z = hmul2(a, b.z);
    c.w = hmul2(a, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hrelu2(uint32_t x, uint32_t lb = 0) {
    uint32_t res;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(lb));
#else
    const uint32_t zero = 0u;
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n" \
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

static inline __device__ uint32_t habs2(uint32_t x) {
    uint32_t res;
    asm volatile("abs.f16x2 %0, %1;\n" : "=r"(res) : "r"(x));
    return res;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
template< typename T >
static inline __device__ T clamp(T x, T lb, T ub) {
    return x < lb ? lb : (x > ub ? ub : x);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint16_t clamp_to_zero(uint16_t x) {
    uint16_t mask;
    asm volatile("set.gtu %0, %1, 0;" : "=h"(mask) : "h"(x));
    return mask & x;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint16_t float_to_half(float f) {
    uint16_t h;
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
    return h;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t float2_to_half2(float a, float b) {
    uint32_t c;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(c) : "f"(b), "f"(a));
#else
    uint16_t lo = float_to_half(a);
    uint16_t hi = float_to_half(b);
    asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(c) : "h"(lo), "h"(hi));
#endif
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t float_to_half2(float a) {
    return float2_to_half2(a, a);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t float2_to_half2(const float2 &f) {
    return float2_to_half2(f.x, f.y);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint2 float4_to_half4(float x, float y, float z, float w) {
    uint2 d;
    d.x = float2_to_half2(x, y);
    d.y = float2_to_half2(z, w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hfma2(uint32_t a, uint32_t b, uint32_t c) {
    uint32_t d;
    asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hfma2_relu(uint32_t a, uint32_t b, uint32_t c) {
    uint32_t d;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("fma.rn.f16x2.relu %0, %1, %2, %3;" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
    d = hrelu2(hfma2(a, b, c));
#endif
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t h0_h0(uint32_t x) {
    uint32_t y;
    asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {lo, lo};}\n"
                 : "=r"(y) : "r"(x));
    return y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float h0_to_float(uint32_t h2) {
    float f;
    asm volatile("{\n" \
                 ".reg .f16 lo, hi;\n" \
                 "mov.b32 {lo, hi}, %1;\n" \
                 "cvt.f32.f16 %0, lo;\n" \
                 "}\n" : "=f"(f) : "r"(h2));
    return f;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t h1_h1(uint32_t x) {
    uint32_t y;
    asm volatile("{.reg .f16 lo, hi; mov.b32 {lo, hi}, %1; mov.b32 %0, {hi, hi};}\n"
                 : "=r"(y) : "r"(x));
    return y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint16_t hadd(uint16_t a, uint16_t b) {
    uint16_t d;
    asm volatile("add.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint32_t hadd(uint32_t a, uint32_t b) {
    return hadd2(a, b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint2 hadd4(uint2 a, uint2 b) {
    uint2 c;
    c.x = hadd2(a.x, b.x);
    c.y = hadd2(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint2 hadd(uint2 a, uint2 b) {
    return hadd4(a, b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint4 hadd8(uint4 a, uint4 b) {
    uint4 c;
    c.x = hadd2(a.x, b.x);
    c.y = hadd2(a.y, b.y);
    c.z = hadd2(a.z, b.z);
    c.w = hadd2(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint4 fadd4(uint4 a, uint4 b) {
    float4 c;
    c.x = reinterpret_cast<const float &>(a.x) + reinterpret_cast<const float &>(b.x);
    c.y = reinterpret_cast<const float &>(a.y) + reinterpret_cast<const float &>(b.y);
    c.z = reinterpret_cast<const float &>(a.z) + reinterpret_cast<const float &>(b.z);
    c.w = reinterpret_cast<const float &>(a.w) + reinterpret_cast<const float &>(b.w);
    return reinterpret_cast<const uint4 &>(c);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint4 hadd(uint4 a, uint4 b) {
    return hadd8(a, b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float half_to_float(uint16_t h) {
    float f;
    asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
    return f;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float2 half2_to_float2(uint32_t x) {
    uint16_t lo, hi;
    asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(x));
    return make_float2(half_to_float(lo), half_to_float(hi));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ void half2_to_float2(float &x, float &y, uint32_t h) {
    float2 tmp = half2_to_float2(h);
    x = tmp.x;
    y = tmp.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint16_t hfma(uint16_t a, uint16_t b, uint16_t c) {
    uint16_t d;
    asm volatile("fma.rn.f16 %0, %1, %2, %3;" : "=h"(d) : "h"(a), "h"(b), "h"(c));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint16_t hmul(uint16_t a, uint16_t b) {
    uint16_t d;
    asm volatile("mul.f16 %0, %1, %2;" : "=h"(d) : "h"(a), "h"(b));
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ float sigmoid(float x) {
    return 1.f / (1.f + expf(-x));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void clear(uint16_t &dst) {
    dst = uint16_t(0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void clear(uint32_t &dst) {
    dst = 0u;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void clear(uint2 &dst) {
    dst = make_uint2(0u, 0u);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void clear(uint4 &dst) {
    dst = make_uint4(0u, 0u, 0u, 0u);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// P R E D I C A T E   P A C K I N G
//
////////////////////////////////////////////////////////////////////////////////////////////////////

enum { BYTES_PER_REG = 4, PREDS_PER_BYTE = 4, PREDS_PER_REG = BYTES_PER_REG * PREDS_PER_BYTE };

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// G E N E R I C   P R E D I C A T E D   L D G S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M, typename Functor >
inline __device__ void load_(Functor &fct, const uint32_t (&preds)[M]) {

    // The number of complete bytes (where we use all the predicates in a byte).
    enum { COMPLETE = N / PREDS_PER_BYTE };
    // Make sure we did allocate enough predicates.
    static_assert(Div_up<COMPLETE, BYTES_PER_REG>::VALUE <= M, "");
    // The remainder.
    enum { REMAINDER = N - COMPLETE * PREDS_PER_BYTE };
    // Make sure we got the math right and the remainder is between 0 and 3.
    static_assert(REMAINDER >= 0 && REMAINDER <= 3, "");
    // The mask to extract the predicates.
    enum { COMPLETE_MASK = (1 << PREDS_PER_BYTE) - 1 };

    // Clear the fetch registers.
    #pragma unroll
    for( int ii = 0; ii < N; ++ii ) {
        fct.clear(ii);
    }

    // Run complete steps.
    bool p[PREDS_PER_BYTE];
    #pragma unroll
    for( int ii = 0; ii < COMPLETE; ++ii ) {

        // The predicate.
        uint32_t reg = preds[ii / BYTES_PER_REG];

        // Extract the predicates.
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            uint32_t mask = 1u << (ii % BYTES_PER_REG * 8 + jj);
            p[jj] = (reg & mask) != 0u;
        }

        // Issue the loads.
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            fct.load(ii * PREDS_PER_BYTE + jj, p[jj]);
        }
    }

    // Skip the rest of the code if we do not have a remainder.
    if( REMAINDER > 0 ) {

        // The mask to extract the predicates.
        enum { REMAINDER_MASK = (1 << REMAINDER) - 1 };

        // The predicate register.
        uint32_t reg = preds[COMPLETE / BYTES_PER_REG];

        // Extract the predicates.
        #pragma unroll
        for( int jj = 0; jj < PREDS_PER_BYTE; ++jj ) {
            uint32_t mask = 1u << (COMPLETE % BYTES_PER_REG * 8 + jj);
            p[jj] = (reg & mask) != 0u;
        }

        // Issue the loads.
        #pragma unroll
        for( int ii = 0; ii < REMAINDER; ++ii ) {
            fct.load(COMPLETE * PREDS_PER_BYTE + ii, p[ii]);
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int M, typename Functor >
inline __device__ void load_(Functor &fct, uint32_t preds) {
    uint32_t tmp[1] = { preds };
    load_<M>(fct, tmp);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint8_t &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint8_t *>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint16_t &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint16_t *>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint32_t &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint32_t *>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint2 &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint2 *>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint4 &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint4 *>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Data_type, int N >
struct Ldg_functor {
    // Ctor.
    inline __device__ Ldg_functor(Data_type (&fetch)[N], const void *(&ptrs)[N])
        : fetch_(fetch), ptrs_(ptrs) {
    }

    // Clear the element.
    inline __device__ void clear(int ii) {
        fmha::clear(fetch_[ii]);
    }

    // Trigger the loads.
    inline __device__ void load(int ii, bool p) {
        if( p ) {
            ldg(fetch_[ii], ptrs_[ii]);
        }
    }

    // The fetch registers.
    Data_type (&fetch_)[N];
    // The pointers.
    const void *(&ptrs_)[N];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Data_type, int N, int M >
inline __device__ void ldg_(Data_type (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) {
    Ldg_functor<Data_type, N> fct(fetch, ptrs);
    load_<N>(fct, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M >
inline __device__ void ldg(uint8_t (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint8_t, N>(fetch, ptrs, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M >
inline __device__ void ldg(uint16_t (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint16_t, N>(fetch, ptrs, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M >
inline __device__ void ldg(uint32_t (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint32_t, N>(fetch, ptrs, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M >
inline __device__ void ldg(uint2 (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint2, N>(fetch, ptrs, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N, int M >
inline __device__ void ldg(uint4 (&fetch)[N], const void *(&ptrs)[N], uint32_t (&preds)[M]) {
    ldg_<uint4, N>(fetch, ptrs, preds);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void lds(uint16_t &dst, uint32_t ptr) {
    asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(dst) : "r"(ptr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void lds(uint32_t &dst, uint32_t ptr) {
    asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(dst) : "r"(ptr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void lds(uint2 &dst, uint32_t ptr) {
    asm volatile("ld.shared.v2.b32 {%0, %1}, [%2];\n" : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void lds(uint4 &dst, uint32_t ptr) {
    asm volatile("ld.shared.v4.b32 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(dst.x)
                 , "=r"(dst.y)
                 , "=r"(dst.z)
                 , "=r"(dst.w)
                 : "r"(ptr));
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D S M
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsmt(uint32_t &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n" : "=r"(dst) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
                 : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsmt(uint2 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
                 : "=r"(dst.x), "=r"(dst.y) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsm(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldsmt(uint4 &dst, uint32_t ptr) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 730
    asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
                 : "=r"(dst.x), "=r"(dst.y), "=r"(dst.z), "=r"(dst.w) : "r"(ptr));
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint8_t val) {
    *reinterpret_cast<uint8_t *>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint16_t val) {
    *reinterpret_cast<uint16_t *>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint32_t val) {
    *reinterpret_cast<uint32_t *>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint2 val) {
    *reinterpret_cast<uint2 *>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint4 val) {
    *reinterpret_cast<uint4 *>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T S
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint16_t val) {
    asm volatile("st.shared.b16 [%0], %1;\n" : : "r"(ptr), "h"(val));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint32_t val) {
    asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint2 val) {
    asm volatile("st.shared.v2.b32 [%0], {%1, %2};\n"
                 : : "r"(ptr)
                 , "r"(val.x)
                 , "r"(val.y));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void sts(uint32_t ptr, uint4 val) {
    asm volatile("st.shared.v4.b32 [%0], {%1, %2, %3, %4};\n"
                 : : "r"(ptr)
                 , "r"(val.x)
                 , "r"(val.y)
                 , "r"(val.z)
                 , "r"(val.w));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename Data_type, int N >
inline __device__ void sts_(uint32_t (&ptrs)[N], const Data_type (&data)[N]) {
    #pragma unroll
    for( int ii = 0; ii < N; ++ii ) {
        sts(ptrs[ii], data[ii]);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint16_t (&data)[N]) {
    sts_<uint16_t, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint32_t (&data)[N]) {
    sts_<uint32_t, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint2 (&data)[N]) {
    sts_<uint2, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int N >
inline __device__ void sts(uint32_t (&ptrs)[N], const uint4 (&data)[N]) {
    sts_<uint4, N>(ptrs, data);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace fmha
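For orientation, the predicate packing consumed by load_() above places four predicates per byte and four bytes per 32-bit word, so element i is controlled by bit ((i/4) % 4)*8 + i%4 of word i/16. The device-side sketch below shows one way a caller could build the pointer and predicate arrays and hand them to fmha::ldg; it is illustrative only (the tile geometry, the include path, and example_predicated_load are assumptions, not part of the commit), while the real kernels do this inside their Gmem_tile_* classes.

#include <fmha/utils.h>   // assumed include path, following this commit's layout

// Illustrative sketch: load LDGS uint4 rows, guarding out-of-bounds rows with packed predicates.
template< int LDGS >
inline __device__ void example_predicated_load(const char *base, int row, int rows, int stride, uint4 (&fetch)[LDGS]) {
    const void *ptrs[LDGS];
    uint32_t preds[fmha::Div_up<LDGS, fmha::PREDS_PER_REG>::VALUE] = { 0 };
    #pragma unroll
    for( int ii = 0; ii < LDGS; ++ii ) {
        ptrs[ii] = base + (row + ii) * stride;
        // Pack predicate ii: 4 predicates per byte, 4 bytes per 32-bit register.
        if( row + ii < rows ) {
            preds[ii / fmha::PREDS_PER_REG] |=
                1u << (((ii / fmha::PREDS_PER_BYTE) % fmha::BYTES_PER_REG) * 8 + ii % fmha::PREDS_PER_BYTE);
        }
    }
    // In-bounds elements are fetched; out-of-bounds elements are cleared to zero by load_().
    fmha::ldg(fetch, ptrs, preds);
}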
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_128_64_kernel.sm80.cu
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;

extern "C" __global__ void fmha_dgrad_fp16_128_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::compute_dv_1xN<Kernel_traits>(params);
    fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}

void run_fmha_dgrad_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    static_assert(smem_size_s == 16 * 128 * 2);
    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;

    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
            fmha_dgrad_fp16_128_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    dim3 grid(params.h, params.b);
    fmha_dgrad_fp16_128_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
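The launcher above follows the usual Ampere pattern for opting in to more than 48 KB of dynamic shared memory: raise the kernel's cudaFuncAttributeMaxDynamicSharedMemorySize first, then pass the byte count at launch. A minimal self-contained sketch of that pattern is below; the kernel, sizes, and names are placeholders for illustration, not the fmha kernels themselves.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void smem_demo_kernel(float *out) {
    extern __shared__ char smem_[];              // dynamic shared memory, sized at launch
    float *buf = reinterpret_cast<float *>(smem_);
    buf[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = buf[threadIdx.x];
}

int main() {
    constexpr int threads = 256;
    constexpr int smem_size = 64 * 1024;         // > 48 KB, so the opt-in attribute is required
    float *out;
    cudaMalloc(&out, threads * sizeof(float));
    if( smem_size >= 48 * 1024 ) {
        cudaFuncSetAttribute(smem_demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }
    smem_demo_kernel<<<1, threads, smem_size>>>(out);
    cudaError_t err = cudaDeviceSynchronize();
    printf("launch status: %s\n", cudaGetErrorString(err));
    cudaFree(out);
    return 0;
}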
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_256_64_kernel.sm80.cu
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;

extern "C" __global__ void fmha_dgrad_fp16_256_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::compute_dv_1xN<Kernel_traits>(params);
    fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}

void run_fmha_dgrad_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    static_assert(smem_size_s == 16 * 256 * 2);
    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;

    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
            fmha_dgrad_fp16_256_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    dim3 grid(params.h, params.b);
    fmha_dgrad_fp16_256_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_384_64_kernel.sm80.cu
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>;

extern "C" __global__ void fmha_dgrad_fp16_384_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::compute_dv_1xN<Kernel_traits>(params);
    fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}

void run_fmha_dgrad_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    static_assert(smem_size_s == 16 * 384 * 2);
    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;

    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
            fmha_dgrad_fp16_384_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    dim3 grid(params.h, params.b);
    fmha_dgrad_fp16_384_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
apex/contrib/csrc/fmha/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_dgrad_kernel_1xN_reload.h"
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;

extern "C" __global__ void fmha_dgrad_fp16_512_64_sm80_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::compute_dv_1xN<Kernel_traits>(params);
    fmha::compute_dq_dk_1xN<Kernel_traits>(params);
}

void run_fmha_dgrad_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, cudaStream_t stream) {

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    using Smem_tile_s = fmha::Smem_tile_mma_transposed<Kernel_traits::Cta_tile_p>;
    constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
    static_assert(smem_size_s == 16 * 512 * 2);
    static_assert(smem_size_o == 16 * 64 * 4 * Kernel_traits::Cta_tile_p::WARPS_N);

    constexpr int smem_size_dv = smem_size_s + 2 * smem_size_q + smem_size_v + smem_size_softmax;
    constexpr int smem_size_dq_dk = smem_size_s + smem_size_o + smem_size_q + smem_size_v;

    constexpr int smem_size = std::max(smem_size_dv, smem_size_dq_dk);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(
            fmha_dgrad_fp16_512_64_sm80_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    dim3 grid(params.h, params.b);
    fmha_dgrad_fp16_512_64_sm80_kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
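Taken together, the four dgrad launchers in this commit differ only in the sequence length and warp count baked into the traits; the head size (64), the 16-wide K tile, and the remaining flags are identical. The condensed view below is drawn directly from the files above (the meaning of the trailing template parameters is not spelled out in this diff, and the type aliases are named here only for comparison):

#include <fmha/kernel_traits.h>

using Kernel_traits_128 = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;
using Kernel_traits_256 = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;
using Kernel_traits_384 = FMHA_kernel_traits<384, 64, 16, 1, 8, 0x08u>;
using Kernel_traits_512 = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;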
apex/contrib/csrc/fmha/src/fmha_dgrad_kernel_1xN_reload.h
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace
fmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Kernel_traits
,
typename
Params
>
inline
__device__
void
compute_dv_1xN
(
const
Params
&
params
)
{
// The description of the CTA tile for the 1st batched GEMM.
using
Cta_tile_p
=
typename
Kernel_traits
::
Cta_tile_p
;
// The description of the CTA tile for the 2nd batched GEMM.
using
Cta_tile_dv
=
fmha
::
Cta_tile_extd
<
Cta_tile_p
::
N
,
Cta_tile_p
::
K
,
Cta_tile_p
::
M
,
Cta_tile_p
::
WARPS_N
,
1
,
Cta_tile_p
::
WARPS_M
>
;
static_assert
(
Cta_tile_dv
::
M
==
512
||
Cta_tile_dv
::
M
==
384
||
Cta_tile_dv
::
M
==
256
||
Cta_tile_dv
::
M
==
128
);
static_assert
(
Cta_tile_dv
::
N
==
64
);
static_assert
(
Cta_tile_dv
::
K
==
16
);
// The MMA tile for the 1st GEMM.
using
Mma_tile_p
=
fmha
::
Hmma_tile
<
Cta_tile_p
>
;
// The MMA tile for the 2nd GEMM.
using
Mma_tile_dv
=
fmha
::
Hmma_tile
<
Cta_tile_dv
>
;
// The global memory tile to load Q.
using
Gmem_tile_q
=
typename
Kernel_traits
::
Gmem_tile_q
;
// The shared memory tile to swizzle Q.
// using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
using
Smem_tile_q
=
fmha
::
Smem_tile_a
<
Cta_tile_p
,
fmha
::
Row
,
Gmem_tile_q
::
BYTES_PER_LDG
,
2
>
;
// The shared memory tile to reload Q as fragment b.
using
Smem_tile_qt
=
fmha
::
Smem_tile_b
<
Cta_tile_dv
,
fmha
::
Row
,
Gmem_tile_q
::
BYTES_PER_LDG
,
2
>
;
// The global memory tile to load K.
using
Gmem_tile_k
=
typename
Kernel_traits
::
Gmem_tile_k
;
// The shared memory tile to swizzle K.
using
Smem_tile_k
=
typename
Kernel_traits
::
Smem_tile_k
;
// The global memory tile to load V.
using
Gmem_tile_v
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle V.
using
Smem_tile_v
=
typename
Kernel_traits
::
Smem_tile_v
;
// The global memory tile to store O.
using
Gmem_tile_o
=
typename
Kernel_traits
::
Gmem_tile_o
;
// The shared memory tile to swizzle O.
using
Smem_tile_o
=
typename
Kernel_traits
::
Smem_tile_o
;
// The global memory tile to store dV.
using
Gmem_tile_dv
=
typename
Kernel_traits
::
Gmem_tile_v
;
// The shared memory tile to swizzle dV.
using
Smem_tile_dv
=
fmha
::
Smem_tile_mma_epilogue
<
Cta_tile_dv
>
;
static_assert
(
Smem_tile_dv
::
NUM_LDS
==
Gmem_tile_dv
::
LDGS
);
static_assert
(
Smem_tile_dv
::
THREADS_PER_ROW
==
Gmem_tile_dv
::
THREADS_PER_ROW
);
using
Gmem_tile_s
=
typename
Kernel_traits
::
Gmem_tile_s
;
using
Smem_tile_st
=
typename
Kernel_traits
::
Smem_tile_st
;
using
Gmem_tile_do
=
typename
Kernel_traits
::
Gmem_tile_do
;
// Shared memory.
extern
__shared__
char
smem_
[];
// The block index for the batch.
const
int
bidb
=
blockIdx
.
y
;
// The block index for the head.
const
int
bidh
=
blockIdx
.
x
;
// The thread index.
const
int
tidx
=
threadIdx
.
x
;
const
BlockInfoPadded
<
Kernel_traits
::
THREADS
>
binfo
(
params
,
bidb
,
bidh
,
tidx
);
if
(
binfo
.
stop_early
()
)
return
;
Mask
<
Cta_tile_p
>
mask
(
params
,
binfo
,
tidx
);
// Allocate the global memory tile loader for Q.
Gmem_tile_do
gmem_q
(
params
,
binfo
,
tidx
);
// treating dout as Q
// Allocate the shared memory tile loader for Q.
Smem_tile_q
smem_q
(
&
smem_
[
0
],
tidx
);
Smem_tile_qt
smem_qt
(
&
smem_
[
0
],
tidx
);
Smem_tile_st
smem_s
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
],
tidx
);
// Allocate the global memory tile loader for K.
Gmem_tile_k
gmem_k
(
params
,
2
,
binfo
,
tidx
);
// treating V as K
// Allocate the shared memory tile loader for K.
Smem_tile_k
smem_k
(
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
],
tidx
);
// Trigger the loads for Q.
gmem_q
.
load
(
smem_q
);
// Trigger the loads for K.
gmem_k
.
load
(
smem_k
);
// Commit the data for Q and K to shared memory.
gmem_q
.
commit
(
smem_q
);
gmem_k
.
commit
(
smem_k
);
// Make sure the data is in shared memory.
__syncthreads
();
// Load the fragments for Q.
typename
Smem_tile_q
::
Fragment
frag_q
[
2
][
Mma_tile_p
::
MMAS_M
];
smem_q
.
load
(
frag_q
[
0
],
0
);
typename
Smem_tile_qt
::
Fragment
frag_qt
[
2
][
Mma_tile_dv
::
MMAS_N
];
static_assert
(
Smem_tile_qt
::
Fragment
::
NUM_REGS
==
4
);
static_assert
(
Mma_tile_dv
::
MMAS_K
==
1
);
smem_qt
.
load
(
frag_qt
[
0
],
0
);
// Load the fragments for K. We keep the data in registers during the entire kernel.
typename
Smem_tile_k
::
Fragment
frag_k
[
2
][
Mma_tile_p
::
MMAS_N
];
smem_k
.
load
(
frag_k
[
0
],
0
);
enum
{
BITS_PER_ELT_S
=
sizeof
(
fmha
::
A_type
)
*
8
};
Gmem_tile_s
gmem_s
(
params
.
s_ptr
,
params
,
tidx
);
// Create the object to do the softmax.
using
Softmax
=
fmha
::
Softmax
<
Cta_tile_p
,
Kernel_traits
>
;
Softmax
softmax
(
params
,
&
smem_
[
Smem_tile_q
::
BYTES_PER_TILE
+
Smem_tile_st
::
BYTES_PER_TILE
+
Smem_tile_k
::
BYTES_PER_TILE
],
bidb
,
tidx
);
enum
{
THREADS_PER_ROW
=
32
};
enum
{
M
=
Mma_tile_p
::
MMAS_M
};
enum
{
N
=
Mma_tile_p
::
MMAS_N
};
// Declare the accumulators for the 2nd gemm.
fmha
::
Fragment_accumulator
acc_dv
[
Mma_tile_dv
::
MMAS_M
][
Mma_tile_dv
::
MMAS_N
];
fmha
::
Clear_accumulator
<
fmha
::
Accumulator_type
,
Cta_tile_dv
::
WARPS_K
>::
apply
(
acc_dv
);
// Load over the entire sequence length.
for
(
int
loop
=
0
,
outer
=
0
;
loop
<
Cta_tile_p
::
N
;
loop
+=
Cta_tile_p
::
M
,
outer
++
)
{
if
(
loop
>=
binfo
.
actual_seqlen
)
break
;
// Load S
uint4
s_regs
[
M
][
N
];
gmem_s
.
load
(
s_regs
,
mask
);
fmha
::
Fragment_accumulator
acc_p
[
Mma_tile_p
::
MMAS_M
][
Mma_tile_p
::
MMAS_N
];
fmha
::
Clear_accumulator
<
fmha
::
Accumulator_type
,
Cta_tile_p
::
WARPS_K
>::
apply
(
acc_p
);
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for
(
int
ki
=
1
;
ki
<
Mma_tile_p
::
MMAS_K
;
++
ki
)
{
// Trigger the load from shared memory for the next series of Q values.
smem_q
.
load
(
frag_q
[
ki
&
1
],
ki
);
smem_k
.
load
(
frag_k
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
fmha
::
gemm
(
acc_p
,
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
// Store s * dmask to smem for transpose
smem_s
.
store
(
s_regs
);
// Declare the accumulators for the 1st gemm.
// Do the final stage of math.
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm
(
acc_p
,
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
// Convert from the accumulator type to FP32 for Softmax.
softmax
.
unpack
(
acc_p
);
float
s_mat
[
2
*
M
][
4
*
N
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
uint4
&
dst
=
s_regs
[
mi
][
ni
];
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
0
],
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
1
],
dst
.
x
);
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
2
],
s_mat
[
2
*
mi
+
0
][
4
*
ni
+
3
],
dst
.
y
);
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
0
],
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
1
],
dst
.
z
);
fmha
::
half2_to_float2
(
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
2
],
s_mat
[
2
*
mi
+
1
][
4
*
ni
+
3
],
dst
.
w
);
}
}
float
d_s
[
2
*
M
][
4
*
N
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M
;
mi
++
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
2
;
ii
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N
;
ni
++
)
{
#pragma unroll
for
(
int
jj
=
0
;
jj
<
4
;
jj
++
)
{
const
float
s_dmask
=
s_mat
[
2
*
mi
+
ii
][
4
*
ni
+
jj
];
const
bool
drop
=
reinterpret_cast
<
const
uint32_t
&>
(
s_dmask
)
&
0x80000000
;
d_s
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
=
drop
?
0.
f
:
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
*
params
.
rp_dropout
;
softmax
.
elt_
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
=
d_s
[
2
*
mi
+
ii
][
4
*
ni
+
jj
]
*
fabsf
(
s_dmask
);
}
}
}
}
float
p_sum
[
2
*
M
];
softmax
.
template
reduce
<
fmha
::
Sum_
>(
p_sum
);
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ii = 0; ii < 2; ii++ ) {
                #pragma unroll
                for( int ni = 0; ni < N; ni++ ) {
                    #pragma unroll
                    for( int jj = 0; jj < 4; jj++ ) {
                        const float scalef = reinterpret_cast<const float &>(params.scale_softmax);
                        softmax.elt_[2 * mi + ii][4 * ni + jj] =
                            (d_s[2 * mi + ii][4 * ni + jj] - p_sum[2 * mi + ii]) * fabsf(s_mat[2 * mi + ii][4 * ni + jj]) * scalef;
                    }
                }
            }
        }

        // Trigger the load for the next Q values. We're using double buffering, so reading qt is safe
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        typename Smem_tile_st::Fragment frag_s[Mma_tile_dv::MMAS_K][Mma_tile_dv::MMAS_M];
        smem_s.load(frag_s);
        for( int ki = 0; ki < Mma_tile_dv::MMAS_K; ki++ ) {
            for( int mi = 0; mi < Mma_tile_dv::MMAS_M; mi++ ) {
                for( int ii = 0; ii < Smem_tile_st::Fragment::NUM_REGS; ii++ ) {
                    frag_s[ki][mi].reg(ii) = fmha::hmul2(frag_s[ki][mi].reg(ii), params.scale_dropout);
                    frag_s[ki][mi].reg(ii) = fmha::hrelu2(frag_s[ki][mi].reg(ii));
                }
            }
        }

        gmem_s.store(softmax.elt_, mask);
        gmem_s.move();

        #pragma unroll
        for( int ki = 1; ki < Mma_tile_dv::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_qt.load(frag_qt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dv::MMAS_K;
            fmha::gemm(acc_dv, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Commit the values for Q into shared memory.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            gmem_q.commit(smem_q);
        }

        // Make sure we are reading from the correct buffer.
        smem_q.move_to_next_read_buffer();
        smem_qt.move_to_next_read_buffer();

        // Make sure the data is in shared memory.
        __syncthreads();

        // Trigger the loads for the values of Q for the next iteration.
        smem_q.load(frag_q[0], 0);
        smem_k.load(frag_k[0], 0);
        smem_qt.load(frag_qt[0], 0);

    }  // Outer loop over the sequence length.

    // Epilogue swizzle for dV
    Smem_tile_dv smem_dv(&smem_[Kernel_traits::Smem_tile_q::BYTES_PER_TILE], tidx);

    uint4 dv[Mma_tile_dv::MMAS_M][Mma_tile_dv::MMAS_N];
    #pragma unroll
    for( int mi = 0; mi < Mma_tile_dv::MMAS_M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < Mma_tile_dv::MMAS_N; ++ni ) {
            // 1st row - 4 elements per row.
            float tmp00 = acc_dv[mi][ni].elt(0);
            float tmp01 = acc_dv[mi][ni].elt(1);
            float tmp02 = acc_dv[mi][ni].elt(4);
            float tmp03 = acc_dv[mi][ni].elt(5);
            // 2nd row - 4 elements per row.
            float tmp10 = acc_dv[mi][ni].elt(2);
            float tmp11 = acc_dv[mi][ni].elt(3);
            float tmp12 = acc_dv[mi][ni].elt(6);
            float tmp13 = acc_dv[mi][ni].elt(7);

            dv[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
            dv[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
            dv[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
            dv[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
        }
    }

    smem_dv.store(dv);
    __syncthreads();
    uint4 dv_out[Smem_tile_dv::NUM_LDS];
    smem_dv.load(dv_out);

    Qkv_params dv_params;
    dv_params.qkv_ptr = params.dqkv_ptr;
    dv_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
    Gmem_tile_dv gmem_dv(dv_params, 2, binfo, tidx);
    gmem_dv.store(dv_out);
}
template<typename Kernel_traits, typename Params>
inline __device__ void compute_dq_dk_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_dk =
        fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
    static_assert(Cta_tile_dk::M == 512 || Cta_tile_dk::M == 384 || Cta_tile_dk::M == 256 || Cta_tile_dk::M == 128);
    static_assert(Cta_tile_dk::N == 64);
    static_assert(Cta_tile_dk::K == 16);

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_dk = fmha::Hmma_tile<Cta_tile_dk>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_v;
    // K is used like V in fprop

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    // using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    using Gmem_tile_o = fmha::Gmem_tile_dq<Cta_tile_o>;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    // The global memory tile to store dK.
    using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle dK.
    using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dk>;
    static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
    static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);

    // The shared memory tile to reload Q transposed.
    using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dk, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
    using Smem_tile_st = typename Kernel_traits::Smem_tile_st;

    enum { M = Mma_tile_p::MMAS_M };
    enum { N = Mma_tile_p::MMAS_N };
    static_assert(M == Mma_tile_o::MMAS_M);
    static_assert(N == Mma_tile_o::MMAS_K);

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() )
        return;

    Mask<Cta_tile_p> mask(params, binfo, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[0], tidx);
    Smem_tile_qt smem_qt(&smem_[0], tidx);
    Smem_tile_st smem_s(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE], tidx);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);
    // Load dP
    uint4 s_regs[M][N];
    gmem_s.load(s_regs, mask);
    gmem_s.move();

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Make sure the data is in shared memory.
    __syncthreads();

    typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dk::MMAS_N];
    smem_qt.load(frag_qt[0], 0);
    typename Smem_tile_k::Fragment frag_k[2][Mma_tile_o::MMAS_N];
    smem_k.load(frag_k[0], 0);

    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

    enum { THREADS_PER_ROW = 32 };

    // Declare the accumulators for the 2nd gemm.
    fmha::Fragment_accumulator acc_dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
    fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dk::WARPS_K>::apply(acc_dk);

    // Load over the entire sequence length.
    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
        if( loop >= binfo.actual_seqlen )
            break;

        // Pack dP as Fragment_a
        fmha::Fragment_a<fmha::Row> frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        #pragma unroll
        for( int mi = 0; mi < M; mi++ ) {
            #pragma unroll
            for( int ni = 0; ni < N; ni++ ) {
                uint4 &dst = s_regs[mi][ni];
                frag_p[ni][mi].reg(0) = dst.x;  // row 0, cols 0,1
                frag_p[ni][mi].reg(1) = dst.z;  // row 8, cols 0,1
                frag_p[ni][mi].reg(2) = dst.y;  // row 0, cols 8,9
                frag_p[ni][mi].reg(3) = dst.w;  // row 8, cols 8,9
            }
        }

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T. dQ = dP x dK
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_o::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_k.load(frag_k[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_o::MMAS_K;
            fmha::gemm(acc_o, frag_p[ki - 1], frag_k[(ki - 1) & 1]);
        }

        // Store dP to smem for transpose
        smem_s.store(s_regs);

        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            // Load next part of S
            gmem_s.load(s_regs, mask);
            gmem_s.move();
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);
            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
                __syncthreads();
            }
            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        typename Smem_tile_st::Fragment frag_s[Mma_tile_dk::MMAS_K][Mma_tile_dk::MMAS_M];
        smem_s.load(frag_s);

        #pragma unroll
        for( int ki = 1; ki < Mma_tile_dk::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_qt.load(frag_qt[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_dk::MMAS_K;
            fmha::gemm(acc_dk, frag_s[(ki - 1)], frag_qt[(ki - 1) & 1]);
        }

        // Commit the values for Q into shared memory.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            gmem_q.commit(smem_q);
        }

        // Make sure the data is in shared memory.
        __syncthreads();

        // Trigger the loads for the values of Q for the next iteration.
        smem_qt.load(frag_qt[0], 0);
        smem_k.load(frag_k[0], 0);

    }  // Outer loop over the sequence length.

    // Epilogue swizzle for dK
    Smem_tile_dk smem_dk(&smem_[0], tidx);

    uint4 dk[Mma_tile_dk::MMAS_M][Mma_tile_dk::MMAS_N];
    #pragma unroll
    for( int mi = 0; mi < Mma_tile_dk::MMAS_M; ++mi ) {
        #pragma unroll
        for( int ni = 0; ni < Mma_tile_dk::MMAS_N; ++ni ) {
            // 1st row - 4 elements per row.
            float tmp00 = acc_dk[mi][ni].elt(0);
            float tmp01 = acc_dk[mi][ni].elt(1);
            float tmp02 = acc_dk[mi][ni].elt(4);
            float tmp03 = acc_dk[mi][ni].elt(5);
            // 2nd row - 4 elements per row.
            float tmp10 = acc_dk[mi][ni].elt(2);
            float tmp11 = acc_dk[mi][ni].elt(3);
            float tmp12 = acc_dk[mi][ni].elt(6);
            float tmp13 = acc_dk[mi][ni].elt(7);

            dk[mi][ni].x = fmha::float2_to_half2(tmp00, tmp01);
            dk[mi][ni].y = fmha::float2_to_half2(tmp02, tmp03);
            dk[mi][ni].z = fmha::float2_to_half2(tmp10, tmp11);
            dk[mi][ni].w = fmha::float2_to_half2(tmp12, tmp13);
        }
    }

    smem_dk.store(dk);
    __syncthreads();
    uint4 dk_out[Smem_tile_dk::NUM_LDS];
    smem_dk.load(dk_out);

    Qkv_params dk_params;
    dk_params.qkv_ptr = params.dqkv_ptr;
    dk_params.qkv_stride_in_bytes = params.qkv_stride_in_bytes;
    Gmem_tile_dk gmem_dk(dk_params, 1, binfo, tidx);
    gmem_dk.store(dk_out);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
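The scaling loop at the top of compute_dv in this file, softmax.elt_ = (d_s - p_sum) * fabsf(s_mat) * scale, is the standard row-softmax backward identity with the sign-encoded dropout mask folded in. As a sketch, writing P_j for one row of softmax probabilities and dP_j for the incoming gradient:

\[ \frac{\partial \mathcal{L}}{\partial S_j} \;=\; P_j \Big( dP_j \;-\; \sum_k dP_k\, P_k \Big) \]

Here d_s plays the role of dP_j, fabsf(s_mat) recovers P_j from its sign-encoded (dropout-marked) storage, and, presumably, p_sum is the row reduction of dP·P computed earlier in the kernel while scale_softmax carries the attention scaling applied in the forward pass; treat those last two roles as assumptions, since their definitions sit outside this excerpt.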
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_128_64_kernel.sm80.cu
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u>;

extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, true>(params);
}

extern "C" __global__ void fmha_fprop_fp16_128_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, false>(params);
}

void run_fmha_fp16_128_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {

    auto kernel = is_training ? &fmha_fprop_fp16_128_64_sm80_train_kernel : &fmha_fprop_fp16_128_64_sm80_predict_kernel;

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    dim3 grid(params.h, params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
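All four launchers in this commit follow the same host-side pattern: derive the per-CTA shared-memory footprint from the kernel traits and, when it reaches the 48 KB default limit, opt the kernel into a larger dynamic allocation before launching. Below is a minimal, self-contained sketch of that pattern; demo_kernel and the byte counts are hypothetical placeholders, not part of apex or fmhalib.

#include <algorithm>
#include <cuda_runtime.h>

__global__ void demo_kernel(float *out) {
    extern __shared__ char smem_[];                 // dynamic shared memory, sized at launch
    float *buf = reinterpret_cast<float *>(smem_);
    buf[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = buf[threadIdx.x];
}

int main() {
    // Mirror the sizing logic above: Q tile plus the larger of V and (O + softmax scratch).
    constexpr int smem_q = 32 * 1024, smem_v = 16 * 1024, smem_o = 8 * 1024, smem_softmax = 2 * 1024;
    constexpr int smem_size = smem_q + std::max(smem_v, smem_o + smem_softmax);

    // Beyond the 48 KB default, the kernel must opt in explicitly before it can be launched.
    if( smem_size >= 48 * 1024 ) {
        cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    }

    float *out;
    cudaMalloc(&out, 256 * sizeof(float));
    demo_kernel<<<1, 256, smem_size, nullptr>>>(out); // grid, block, dynamic smem bytes, stream
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}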
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_256_64_kernel.sm80.cu
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u>;

extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, true>(params);
}

extern "C" __global__ void fmha_fprop_fp16_256_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, false>(params);
}

void run_fmha_fp16_256_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {

    auto kernel = is_training ? &fmha_fprop_fp16_256_64_sm80_train_kernel : &fmha_fprop_fp16_256_64_sm80_predict_kernel;

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    dim3 grid(params.h, params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_384_64_kernel.sm80.cu
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN_reload_v.h"
using Kernel_traits = FMHA_kernel_traits<384, 64, 16, 1, 4, 0x08u>;

extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, true>(params);
}

extern "C" __global__ void fmha_fprop_fp16_384_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, false>(params);
}

void run_fmha_fp16_384_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {

    auto kernel = is_training ? &fmha_fprop_fp16_384_64_sm80_train_kernel : &fmha_fprop_fp16_384_64_sm80_predict_kernel;

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    constexpr int smem_size = smem_size_v + smem_size_o + smem_size_softmax;

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    dim3 grid(params.h, params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
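Unlike the 128/256/512 launchers, this 384 launcher includes fmha_fprop_kernel_1xN_reload_v.h and budgets shared memory as smem_size_v + smem_size_o + smem_size_softmax, with no separate Q term. In that variant K and V alias the same buffer at offset 0, and Q and O are both placed at offset Smem_tile_v::BYTES_PER_TILE, so the Q tile does not add to the footprint (the launcher evidently assumes it fits within the region it aliases). A rough symbolic comparison, with t_q, t_v, t_o, t_s standing for the tile byte counts rather than their actual values, which depend on the kernel traits:

    regular 1xN:   t_q + max(t_v, t_o + t_s)
    reload-V 1xN:  t_v + t_o + t_s            // K aliases V; Q aliases O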
apex/contrib/csrc/fmha/src/fmha_fprop_fp16_512_64_kernel.sm80.cu
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include "fmha.h"
#include "fmha_fprop_kernel_1xN.h"
using Kernel_traits = FMHA_kernel_traits<512, 64, 16, 1, 8, 0x08u>;

extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_train_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, true>(params);
}

extern "C" __global__ void fmha_fprop_fp16_512_64_sm80_predict_kernel(Fused_multihead_attention_fprop_params params) {
    fmha::device_1xN<Kernel_traits, false>(params);
}

void run_fmha_fp16_512_64_sm80(const Fused_multihead_attention_fprop_params &params, bool is_training, cudaStream_t stream) {

    auto kernel = is_training ? &fmha_fprop_fp16_512_64_sm80_train_kernel : &fmha_fprop_fp16_512_64_sm80_predict_kernel;

    constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
    constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
    constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
    constexpr int smem_size_o = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;

    constexpr int smem_size = smem_size_q + std::max(smem_size_v, smem_size_o + smem_size_softmax);

    if( smem_size >= 48 * 1024 ) {
        FMHA_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    }
    dim3 grid(params.h, params.b);
    kernel<<<grid, Kernel_traits::THREADS, smem_size, stream>>>(params);
}
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN.h
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() )
        return;

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));

    Mask<Cta_tile_p> mask(params, binfo, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[0], tidx);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = nullptr;
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE];
    } else {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
    }
    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_q::BYTES_PER_TILE], tidx);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);
    // Trigger the loads for V.
    gmem_v.load(smem_v);

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Commit the data for V to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_v.commit(smem_v);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
    typename Smem_tile_q::Fragment frag_q[2][Mma_tile_p::MMAS_M];
    smem_q.load(frag_q[0], 0);

    // Load the fragments for K. We keep the data in registers during the entire kernel.
    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_k.load(frag_k[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_v.commit(smem_v);
        // Make sure the data is in shared memory.
        __syncthreads();
    }

    // Load the fragments for V. We keep the data in registers during the entire kernel.
    typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
        smem_v.load(frag_v[ki], ki);
    }

    enum { BITS_PER_ELT_S = sizeof(fmha::A_type) * 8 };

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);

    // Create the object to do the softmax.
    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
    Softmax softmax(params, &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    enum { THREADS_PER_ROW = 32 };

    // Load over the entire sequence length.
    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
        if( loop >= binfo.actual_seqlen )
            break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        // Do this part of P^T = (Q * K^T)^T.
        #pragma unroll
        for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_q.load(frag_q[ki & 1], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }

        // Do the final stage of math.
        {
            int ki = Mma_tile_p::MMAS_K;
            fmha::gemm(acc_p, frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
        }

        // Store the P matrix.
#if defined(STORE_P)
        gmem_p.store(acc_p);
#endif

        // Load the mask for that iteration.
        mask.load(outer);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && loop == 0 ) {
            // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
            __syncthreads();
        }

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Max_>(p_max);

        // Make sure we are done reading shared memory.
        __syncthreads();

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Sum_>(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
                        // pre-existing zeros
                        softmax.elt_[2 * mi + ii][4 * ni + 0] =
                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] =
                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] =
                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] =
                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        // Trigger the load for the next Q values.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    // "Apply" the dropout.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        // Do this part of O = P^T * V^T.
        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);
            // Make sure the data was read from shared memory.
            if( ii < Gmem_tile_o::LOOPS - 1 ) {
                __syncthreads();
            }
            // Output the values.
            gmem_o.store(out, ii);
        }

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            gmem_q.commit(smem_q);
        }

        // Make sure the data is in shared memory.
        __syncthreads();

        // Trigger the loads for the values of Q for the next iteration.
        smem_q.load(frag_q[0], 0);

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
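The Is_training branch above stores the dropout decision in the sign bit of the (non-negative) softmax output, so a single S buffer carries both the probabilities and the mask to the backward pass; the consumer then rescales and clamps, which zeroes the dropped entries. A scalar sketch of the idea, assuming params.p_dropout acts as the keep probability (as the tmp.x <= params.p_dropout comparison suggests); the helper names are illustrative only, and the real kernel does this on packed half2 registers with fmha::hmul2 / fmha::hrelu2:

#include <cstdio>

// Forward: keep the probability as-is, or flip its sign to mark it as dropped.
float encode_dropout(bool keep, float val) { return keep ? val : -val; }

// Consumer: rescale by 1/p_keep and clamp negatives to zero, so dropped entries vanish
// while the sign bit survives in the stored S matrix for the backward pass.
float apply_dropout(float encoded, float scale) {
    float v = encoded * scale;
    return v > 0.f ? v : 0.f;   // the relu plays the role of fmha::hrelu2
}

int main() {
    const float p_keep = 0.9f, scale = 1.f / p_keep;
    float kept    = apply_dropout(encode_dropout(true,  0.25f), scale);   // ~0.2778
    float dropped = apply_dropout(encode_dropout(false, 0.25f), scale);   // 0
    printf("kept=%f dropped=%f\n", kept, dropped);
    return 0;
}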
apex/contrib/csrc/fmha/src/fmha_fprop_kernel_1xN_reload_v.h
0 → 100644
View file @
5c9b21d8
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include "fmha_kernel.h"
#include <fmha/kernel_traits.h>
#include <fmha/gemm.h>
namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Kernel_traits, bool Is_training, typename Params>
inline __device__ void device_1xN(const Params &params) {

    // The description of the CTA tile for the 1st batched GEMM.
    using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
    // The description of the CTA tile for the 2nd batched GEMM.
    using Cta_tile_o = typename Kernel_traits::Cta_tile_o;

    // The MMA tile for the 1st GEMM.
    using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
    // The MMA tile for the 2nd GEMM.
    using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;

    // The global memory tile to load Q.
    using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
    // The shared memory tile to swizzle Q.
    using Smem_tile_q = typename Kernel_traits::Smem_tile_q;

    // The global memory tile to load K.
    using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
    // The shared memory tile to swizzle K.
    using Smem_tile_k = typename Kernel_traits::Smem_tile_k;

    // The global memory tile to load V.
    using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
    // The shared memory tile to swizzle V.
    using Smem_tile_v = typename Kernel_traits::Smem_tile_v;

    // The global memory tile to store O.
    using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
    // The shared memory tile to swizzle O.
    using Smem_tile_o = typename Kernel_traits::Smem_tile_o;

    using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;

    // Shared memory.
    extern __shared__ char smem_[];

    // The block index for the batch.
    const int bidb = blockIdx.y;
    // The block index for the head.
    const int bidh = blockIdx.x;
    // The thread index.
    const int tidx = threadIdx.x;

    const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
    if( binfo.stop_early() )
        return;

    Mask<Cta_tile_p> mask(params, binfo, tidx);

    auto seeds = at::cuda::philox::unpack(params.philox_args);
    Philox ph(std::get<0>(seeds), binfo.tidx_global, std::get<1>(seeds));

    static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);

    // Allocate the global memory tile loader for K.
    Gmem_tile_k gmem_k(params, 1, binfo, tidx);
    // Allocate the shared memory tile loader for K.
    Smem_tile_k smem_k(&smem_[0], tidx);

    // Allocate the global memory tile loader for V.
    Gmem_tile_v gmem_v(params, 2, binfo, tidx);
    // The base pointer of smem_v;
    char *smem_v_ = nullptr;
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        smem_v_ = &smem_[0];
    } else {
        smem_v_ = &smem_[Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE];
    }
    static_assert(Kernel_traits::SHARE_SMEM_FOR_K_AND_V);
    static_assert(Smem_tile_k::BYTES_PER_TILE == Smem_tile_v::BYTES_PER_TILE);
    // Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
    Smem_tile_v smem_v(smem_v_, tidx);

    // Allocate the global memory tile loader for Q.
    Gmem_tile_q gmem_q(params, 0, binfo, tidx);
    // Allocate the shared memory tile loader for Q.
    Smem_tile_q smem_q(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);

    // Allocate the global memory tile loader for O.
    Gmem_tile_o gmem_o(params, binfo, tidx);
    // Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
    Smem_tile_o smem_o(&smem_[Smem_tile_v::BYTES_PER_TILE], tidx);

    // Trigger the loads for Q.
    gmem_q.load(smem_q);
    // Trigger the loads for K.
    gmem_k.load(smem_k);
    // Trigger the loads for V.
    gmem_v.load(smem_v);

    // Commit the data for Q and K to shared memory.
    gmem_q.commit(smem_q);
    gmem_k.commit(smem_k);

    // Commit the data for V to shared memory.
    if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        gmem_v.commit(smem_v);
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // Load the fragments for Q.
    typename Smem_tile_q::Fragment frag_q[1][Mma_tile_p::MMAS_M];

    // Load the fragments for K. We keep the data in registers during the entire kernel.
    typename Smem_tile_k::Fragment frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
    #pragma unroll
    for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
        smem_k.load(frag_k[ki], ki);
    }

    // Commit the data for V to shared memory if it has not been done already.
    if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
        // Make sure we are done loading the fragments for K.
        __syncthreads();
        // Commit the data to shared memory for V.
        gmem_v.commit(smem_v);
    }

    enum { BITS_PER_ELT_S = sizeof(typename fmha::A_type) * 8 };

    Gmem_tile_s gmem_s(params.s_ptr, params, tidx);

    // Create the object to do the softmax.
    using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
    Softmax softmax(params, &smem_[Smem_tile_v::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE], bidb, tidx);

    constexpr int SMEM_BYTES_SOFTMAX = Softmax::ELEMENTS * sizeof(float);
    static_assert(SMEM_BYTES_SOFTMAX == Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float));

    enum { THREADS_PER_ROW = 32 };

    const float pinv = 1.f / params.p_dropout;

    // Load over the entire sequence length.
    for( int loop = 0, outer = 0; loop < Cta_tile_p::N; loop += Cta_tile_p::M, outer++ ) {
        if( loop >= binfo.actual_seqlen )
            break;

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of Q values.
            smem_q.load(frag_q[0], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_p, frag_q[0], frag_k[ki]);
        }

        // Load the mask for that iteration.
        mask.load(outer);

        // Convert from the accumulator type to FP32 for Softmax.
        softmax.unpack(acc_p);

        // Apply the mask.
        softmax.apply_mask(mask);

        static_assert(2 * Mma_tile_p::MMAS_M * 4 * Mma_tile_p::MMAS_N <= 64);

        // Compute the max.
        float p_max[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Max_>(p_max);

        // Make sure we are done reading shared memory.
        __syncthreads();

        // Compute the exponential value.
        softmax.apply_exp(p_max);

        // Compute the sum.
        float p_sum[Mma_tile_p::MMAS_M * 2];
        softmax.template reduce<fmha::Sum_>(p_sum);

        // Finalize softmax on the accumulators of P^T.
        softmax.scale(p_sum);

        __syncthreads();

        if( Is_training ) {
            auto encode_dropout = [](bool keep, float val) { return keep ? val : -val; };
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_p::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < 2; ii++ ) {
                    #pragma unroll
                    for( int ni = 0; ni < Mma_tile_p::MMAS_N; ni++ ) {
                        float4 tmp = uniform4(ph());
                        // We encode the dropout pattern in the sign bit of the non-negative softmax to distinguish from
                        // pre-existing zeros
                        softmax.elt_[2 * mi + ii][4 * ni + 0] =
                            encode_dropout(tmp.x <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 0]);
                        softmax.elt_[2 * mi + ii][4 * ni + 1] =
                            encode_dropout(tmp.y <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 1]);
                        softmax.elt_[2 * mi + ii][4 * ni + 2] =
                            encode_dropout(tmp.z <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 2]);
                        softmax.elt_[2 * mi + ii][4 * ni + 3] =
                            encode_dropout(tmp.w <= params.p_dropout, softmax.elt_[2 * mi + ii][4 * ni + 3]);
                    }
                }
            }
            gmem_s.store(softmax.elt_, mask);
            gmem_s.move();
        }

        // Trigger the load for the next Q values.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            smem_q.move_to_next_write_buffer();
            gmem_q.move();
            gmem_q.load(smem_q);
        }

        typename Smem_tile_v::Fragment frag_v[1][Mma_tile_o::MMAS_N];

        using Frag_p = fmha::Fragment_a<fmha::Row>;
        Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
        softmax.pack(frag_p);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
            #pragma unroll
            for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
                #pragma unroll
                for( int ii = 0; ii < Frag_p::NUM_REGS; ii++ ) {
                    // "Apply" the dropout.
                    frag_p[ki][mi].reg(ii) = fmha::hmul2(frag_p[ki][mi].reg(ii), params.scale_dropout);
                    frag_p[ki][mi].reg(ii) = fmha::hrelu2(frag_p[ki][mi].reg(ii));
                }
            }
        }

        // Declare the accumulators for the 1st gemm.
        fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
        fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);

        #pragma unroll
        for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
            // Trigger the load from shared memory for the next series of V values.
            smem_v.load(frag_v[0], ki);
            // Do the math for the values already in registers.
            fmha::gemm(acc_o, frag_p[ki], frag_v[0]);
        }

        // Loop over MMAS_M.
        #pragma unroll
        for( int ii = 0; ii < Gmem_tile_o::LOOPS; ++ii ) {
            // Swizzle the elements and do the final reduction.
            smem_o.store(acc_o, ii);
            // Make sure the data is in shared memory.
            __syncthreads();
            // Load from shared memory.
            uint4 out[Gmem_tile_o::STGS_PER_LOOP];
            smem_o.load(out);
            // Always sync after last iter: shared smem_q and smem_o!
            __syncthreads();
            // Output the values.
            gmem_o.store(out, ii);
        }  // same smem as o

        // Move to the next part of the output.
        gmem_o.move();

        // Commit the values for Q into shared memory.
        if( loop + Cta_tile_p::M < Cta_tile_p::N ) {
            gmem_q.commit(smem_q);
        }

        // Make sure the data is in shared memory.
        __syncthreads();

    }  // Outer loop over the sequence length.
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace fmha
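The reload-V header differs from fmha_fprop_kernel_1xN.h mainly in where V lives during the second GEMM; the contrast, reduced to its loop structure (a sketch of the two loops that appear above, using the same fmha fragment and tile names):

// fmha_fprop_kernel_1xN.h: V is loaded into registers once, before the sequence loop.
//   for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
//       fmha::gemm(acc_o, frag_p[ki], frag_v[ki]);   // frag_v[MMAS_K][MMAS_N] stays resident
//   }
//
// fmha_fprop_kernel_1xN_reload_v.h: only one V fragment is resident; each step re-reads smem.
//   for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
//       smem_v.load(frag_v[0], ki);                  // reload this K-slice of V
//       fmha::gemm(acc_o, frag_p[ki], frag_v[0]);
//   }

Keeping a single V fragment resident frees registers at the cost of extra shared-memory reads, which is presumably why the 384 sequence-length launcher is the configuration that includes this variant.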