Unverified commit 6c7d9992, authored by q.yao, committed by GitHub

support fmha (#9)

* support fmha

* update sm by cudaarch

* update ldscript path

* clang-format

* clang-format

---------
parent 62c60806
cmake_minimum_required(VERSION 3.8)
add_library(llama_fmha STATIC llama_flash_attention_kernel.cu)
target_include_directories(llama_fmha PRIVATE ${CUTLASS_DIR}/examples)
target_link_libraries(llama_fmha PRIVATE nvidia::cutlass::cutlass)
set_property(TARGET llama_fmha PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET llama_fmha PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once
#include "cutlass/functional.h"
#include "cutlass/gemm/warp/mma_simt_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
#include "cutlass/matrix_shape.h"
/*
  TensorCores have different accumulator layouts.
  This file provides a class to easily map the i-th accumulator element
  to its corresponding matrix row/column.
*/
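/*
  Illustrative usage sketch (not part of this header). Given one of the
  iterator classes below as `Iter`, a warp accumulator fragment `frag`, and a
  binary max functor `max_op` (all of these names are assumptions about the
  calling kernel), a per-row maximum over the accumulator could be computed as:

    auto lane_offset = Iter::get_lane_offset(lane_id, warp_id, tile_offset);
    accum_t row_max;
    Iter::iterateRows(
        lane_offset,
        [&](int accum_m) { row_max = -cutlass::platform::numeric_limits<accum_t>::infinity(); },
        [&](int accum_m, int accum_n, int idx) { row_max = max_op(row_max, frag[idx]); },
        [&](int accum_m) { Iter::reduceSameRow(lane_id, row_max, max_op); });

  Because beginRow/op/endRow are arbitrary callables, the same traversal can
  also drive rescaling or exponentiation passes over the accumulator.
*/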
template<typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm80 {
    static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
                  "only RowMajor is supported");

    using Policy           = typename T::Policy;
    using InstructionShape = typename T::InstructionShape;
    using OpDelta          = typename T::OpDelta;
    using Shape            = typename T::Shape;

    static int const kElementsPerAccess = InstructionShape::kN / 4;
    static int const kRowsPerTile       = 8;
    static int const kAccumulatorRows   = InstructionShape::kM / kRowsPerTile;

    static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(int8_t lane_id,
                                                               int8_t warp_id,
                                                               typename T::TensorCoord const& tile_offset)
    {
        int quad         = (lane_id >> 2);
        int lane_in_quad = (lane_id & 3);
        return cutlass::MatrixCoord(quad + tile_offset.row() * Shape::kRow,
                                    lane_in_quad * kElementsPerAccess + tile_offset.column() * Shape::kColumn);
    }

    template<typename FA, typename FB, typename FC>
    CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, FA beginRow, FB op, FC endRow)
    {
        // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h
        CUTLASS_PRAGMA_UNROLL
        for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
            CUTLASS_PRAGMA_UNROLL
            for (int row = 0; row < kAccumulatorRows; ++row) {
                int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + row * kRowsPerTile + lane_offset.row();
                beginRow(accum_m);
                CUTLASS_PRAGMA_UNROLL
                for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
                    int mma_accum_start =
                        kAccumulatorRows * kElementsPerAccess * (mma_n * Policy::MmaIterations::kRow + mma_m);
                    CUTLASS_PRAGMA_UNROLL
                    for (int col = 0; col < kElementsPerAccess; ++col) {
                        int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + col + lane_offset.column();
                        int idx = mma_accum_start + row * kElementsPerAccess + col;
                        op(accum_m, accum_n, idx);
                    }
                }
                endRow(accum_m);
            }
        }
    }

    template<typename DT, typename F>
    CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
    {
        // In each warp, 4 threads will work on the same row
        // - the ones with the same `quad`
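        // XOR-butterfly over the quad: after shuffling by 1 and then by 2,
        // every lane holds fn(fn(v0, v1), fn(v2, v3)); the lane with
        // lane_in_quad == 0 reports true and should publish the result.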
        auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1);
        myValue = fn(myValue, otherV);
        otherV  = __shfl_xor_sync(0xffffffff, myValue, 2);
        myValue = fn(myValue, otherV);
        int lane_in_quad = (lane_id & 3);
        return lane_in_quad == 0;
    }
};
template<typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSm70 {
    static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
                  "only RowMajor is supported");

    using Policy           = typename T::Policy;
    using InstructionShape = typename T::InstructionShape;
    using OpDelta          = typename T::OpDelta;
    using Shape            = typename T::Shape;
    using Element          = accum_t;

    static int const kElementsPerPartial = 4;
    using EleShapePerPatial = typename cutlass::platform::conditional<cutlass::platform::is_same<Element, float>::value,
                                                                      cutlass::MatrixShape<2, 2>,
                                                                      cutlass::MatrixShape<1, 4>>::type;
    static int const kElementsPerMma     = 8;
    static int const kAccumulatorPatials = 2;
    using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>;

    static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(int8_t lane_id,
                                                               int8_t warp_id,
                                                               typename T::TensorCoord const& tile_offset)
    {
        int quad         = (lane_id >> 2);
        int lane_in_quad = (lane_id & 3);
        int accum_m, accum_n;

        if (cutlass::platform::is_same<Element, float>::value) {
            // (quad[2],quad[0])+lane_in_quad[0]
            accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1);
            // (quad[1])+lane_in_quad[1]
            accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + (lane_in_quad & 2);
        }
        else {
            accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + lane_in_quad;  // (quad[2],quad[0])
            accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials;
        }
        return cutlass::MatrixCoord(accum_m + tile_offset.row() * Shape::kRow,
                                    accum_n + tile_offset.column() * Shape::kColumn);
    }

    template<typename DT, typename F>
    CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
    {
        static_assert(cutlass::platform::is_same<Element, float>::value, "update to support non-float accum");
        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
        // T0 & T2 share the same line within a quad
        auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1);
        myValue = fn(myValue, otherV);
        // quad 0 and quad 2 are on the same lines
        otherV  = __shfl_xor_sync(0xffffffff, myValue, 1 << 3);
        myValue = fn(myValue, otherV);
        return (lane_id & ((1 << 1) | (1 << 3))) == 0;
    }

    template<typename FA, typename FB, typename FC>
    CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, FA beginRow, FB op, FC endRow)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) {
            CUTLASS_PRAGMA_UNROLL
            for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) {
                CUTLASS_PRAGMA_UNROLL
                for (int m = 0; m < EleShapePerPatial::kRow; ++m) {
                    int accum_m = tile_m * Policy::InterleavedTile::kRow + mma_m * QuadShapePerPatialMma::kRow
                                  + m * 2 + lane_offset.row();
                    beginRow(accum_m);
                    CUTLASS_PRAGMA_UNROLL
                    for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) {
                        CUTLASS_PRAGMA_UNROLL
                        for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) {
                            CUTLASS_PRAGMA_UNROLL
                            for (int p = 0; p < kAccumulatorPatials; ++p) {
                                CUTLASS_PRAGMA_UNROLL
                                for (int n = 0; n < EleShapePerPatial::kColumn; ++n) {
                                    int mma_accum_start = (((tile_n * Policy::TileIterations::kRow + tile_m)
                                                                * Policy::MmaIterations::kColumn
                                                            + mma_n)
                                                               * Policy::MmaIterations::kRow
                                                           + mma_m)
                                                          * kElementsPerMma;
                                    int accum_n = tile_n * Policy::InterleavedTile::kColumn
                                                  + mma_n * QuadShapePerPatialMma::kColumn
                                                  + p * Policy::InterleavedTile::kColumn / 2 + n
                                                  + lane_offset.column();
                                    int idx = mma_accum_start + p * kElementsPerPartial
                                              + m * EleShapePerPatial::kColumn + n;
                                    op(accum_m, accum_n, idx);
                                }
                            }
                        }
                    }
                    endRow(accum_m);
                }
            }
        }
    }
};
template<typename T, typename accum_t, int kWarpSize>
struct AccumLambdaIteratorSimt {
    using Policy     = typename T::Policy;
    using Iterations = typename T::Iterations;
    using Element    = typename T::Element;
    using Delta      = typename T::Delta;
    using Shape      = typename T::Shape;
    static_assert(cutlass::platform::is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
                  "only RowMajor is supported");

    template<typename DT, typename F>
    CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn)
    {
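        // Lanes owning the same accumulator row are consecutive along the
        // warp's column dimension, so a log2(WarpShape::kColumn)-step XOR
        // butterfly reduces across all of them; the lane at column 0 reports
        // true and holds the reduced value.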
        CUTLASS_PRAGMA_UNROLL
        for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) {
            auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit);
            myValue = fn(myValue, otherV);
        }
        return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0;
    }

    template<typename FA, typename FB, typename FC>
    CUTLASS_DEVICE static void iterateRows(cutlass::MatrixCoord& lane_offset, FA beginRow, FB op, FC endRow)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) {
            CUTLASS_PRAGMA_UNROLL
            for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) {
                int accum_m = mma_m * Delta::kRow + m + lane_offset.row();
                beginRow(accum_m);
                CUTLASS_PRAGMA_UNROLL
                for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) {
                    int accum_n = mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + lane_offset.column();
                    CUTLASS_PRAGMA_UNROLL
                    for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) {
                        int idx = n
                                  + Policy::LaneMmaShape::kN
                                        * (mma_n + Iterations::kColumn * (m + mma_m * Policy::LaneMmaShape::kM));
                        op(accum_m, accum_n + n, idx);
                    }
                }
                endRow(accum_m);
            }
        }
    }

    static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset(int8_t lane_id,
                                                               int8_t warp_id,
                                                               typename T::TensorCoord const& tile_offset)
    {
        static_assert(
            cutlass::platform::is_same<typename Policy::LaneLayout, cutlass::layout::RowMajorInterleaved<1>>::value,
            "");
        typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
        cutlass::MatrixCoord lane_offset =
            lane_layout.inverse(lane_id) * cutlass::MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN);
        return lane_offset + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn);
    }
};
template<typename T, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator;

// Simt
template<typename S, typename P, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
    cutlass::gemm::warp::
        MmaSimtTileIterator<S, cutlass::gemm::Operand::kC, accum_t, cutlass::layout::RowMajor, P, 1, 1>,
    accum_t,
    kWarpSize> {
    using WarpIterator = typename cutlass::gemm::warp::
        MmaSimtTileIterator<S, cutlass::gemm::Operand::kC, accum_t, cutlass::layout::RowMajor, P, 1, 1>;
    using Iterator = AccumLambdaIteratorSimt<WarpIterator, accum_t, kWarpSize>;
};

// TensorOp - Volta
template<typename S1, typename S2, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
    cutlass::gemm::warp::
        MmaVoltaTensorOpAccumulatorTileIterator<S1, accum_t, cutlass::layout::RowMajor, S2, cutlass::MatrixShape<1, 1>>,
    accum_t,
    kWarpSize> {
    using WarpIterator = typename cutlass::gemm::warp::
        MmaVoltaTensorOpAccumulatorTileIterator<S1, accum_t, cutlass::layout::RowMajor, S2, cutlass::MatrixShape<1, 1>>;
    using Iterator = AccumLambdaIteratorSm70<WarpIterator, accum_t, kWarpSize>;
};

// TensorOp - Sm75+
template<typename S1, typename S2, typename S3, typename accum_t, int kWarpSize>
struct DefaultMmaAccumLambdaIterator<
    cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator<S1, accum_t, cutlass::layout::RowMajor, S2, S3>,
    accum_t,
    kWarpSize> {
    using WarpIterator = typename cutlass::gemm::warp::
        MmaTensorOpAccumulatorTileIterator<S1, accum_t, cutlass::layout::RowMajor, S2, S3>;
    using Iterator = AccumLambdaIteratorSm80<WarpIterator, accum_t, kWarpSize>;
};
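
// Illustrative selection sketch (hedged): a kernel typically derives the
// iterator from its warp-level MMA's accumulator tile iterator. `Mma` and
// `accum_t` below are assumed names from the calling kernel, not defined here.
//
//   using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
//       typename Mma::Operator::IteratorC, accum_t, /*kWarpSize=*/32>::Iterator;
//
// The partial specializations above then resolve to the Simt, Volta (Sm70),
// or Sm75+ variant based on that iterator type.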
/***************************************************************************************************
 * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice,
 * this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#include <cutlass/cutlass.h>
#include "cutlass/aligned_buffer.h"
#include "cutlass/array.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/pitch_linear.h"
#include "cutlass/numeric_types.h"
#include "cutlass/transform/pitch_linear_thread_map.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "cutlass/transform/threadblock/regular_tile_iterator.h"
template<typename scalar_t,              // scalar type
         typename ThreadblockTileShape,  // size of tile to load
         int Threads,                    // number of participating threads
         int ElementsPerAccess>          // thread access width in elements
class TileSmemLoader {
public:
    using SmemTile  = cutlass::AlignedBuffer<scalar_t, ThreadblockTileShape::kCount>;
    using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap<
        cutlass::layout::PitchLinearShape<ThreadblockTileShape::kColumn,  // contiguous
                                          ThreadblockTileShape::kRow>,    // strided
        Threads,                                                          // Threads
        ElementsPerAccess>;                                               // ElementsPerAccess
    using GmemTileIterator =
        cutlass::transform::threadblock::PredicatedTileIterator<ThreadblockTileShape,       // Shape
                                                                scalar_t,                   // Element
                                                                cutlass::layout::RowMajor,  // Layout
                                                                0,                          // AdvanceRank
                                                                ThreadMap>;                 // ThreadMap
    using SmemTileIterator =
        cutlass::transform::threadblock::RegularTileIterator<ThreadblockTileShape,       // Shape
                                                             scalar_t,                   // Element
                                                             cutlass::layout::RowMajor,  // Layout
                                                             0,                          // AdvanceRank
                                                             ThreadMap>;                 // ThreadMap
    using Fragment = typename GmemTileIterator::Fragment;

    /// load a tile from global memory into shared memory
    CUTLASS_DEVICE
    static void load(GmemTileIterator tile_load_iter, SmemTileIterator tile_store_iter)
    {
        Fragment tb_frag;
        tb_frag.clear();
        tile_load_iter.load(tb_frag);
        tile_store_iter.store(tb_frag);
        __syncthreads();
    }
};
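
// Usage sketch (hedged; every identifier below other than TileSmemLoader's
// own members is an assumption about the calling kernel):
//
//   using Loader = TileSmemLoader<cutlass::half_t,
//                                 cutlass::MatrixShape<32, 128>,  // tile shape
//                                 /*Threads=*/128,
//                                 /*ElementsPerAccess=*/8>;
//   __shared__ typename Loader::SmemTile smem;
//   typename Loader::GmemTileIterator gmem_it(
//       {cutlass::layout::RowMajor(ld_src)}, src_ptr, {rows, cols}, thread_id);
//   typename Loader::SmemTileIterator smem_it(
//       {smem.data(), cutlass::layout::RowMajor(128)}, thread_id);
//   Loader::load(gmem_it, smem_it);
//
// load() pulls one predicated tile from global memory into registers, stores
// it to shared memory, then synchronizes the threadblock.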
...
@@ -153,7 +153,7 @@ set_target_properties(
     INSTALL_RPATH_USE_LINK_PATH FALSE
     INSTALL_RPATH "$\{ORIGIN\}"
     LINK_DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/libtriton_fastertransformer.ldscript
-    LINK_FLAGS "-Wl,--no-as-needed,--version-script libtriton_fastertransformer.ldscript"
+    LINK_FLAGS "-Wl,--no-as-needed,--version-script ${CMAKE_CURRENT_BINARY_DIR}/libtriton_fastertransformer.ldscript"
 )

 # Need to turn off unused-but-set-variable due to Torchvision
...