Unverified Commit 013f0c4f authored by Luka Govedič, committed by GitHub

CMake build, allowing parent build (#19)

parent 344c988d
Pipeline #2020 failed with stages in 0 seconds
cmake_minimum_required(VERSION 3.26)
project(vllm_flash_attn LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_EXTENSIONS OFF)
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")
# Supported python versions. These should be kept in sync with setup.py.
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "8.0;8.6;8.9;9.0")
# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
#
# Supported/expected torch versions for CUDA/ROCm.
#
# Currently, having an incorrect pytorch version results in a warning
# rather than an error.
#
# Note: these should be kept in sync with the torch version in setup.py.
# Likely should also be in sync with the vLLM version.
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
find_python_constrained_versions(${PYTHON_SUPPORTED_VERSIONS})
if (VLLM_PARENT_BUILD)
# vLLM extracts the supported architectures from the global CMAKE_CUDA_FLAGS, which are set by torch.
# Because CMAKE_CUDA_FLAGS has been modified, we cannot use the same logic.
# Hence, we just use the parent's VLLM_GPU_ARCHES and VLLM_GPU_FLAGS.
message(STATUS "Building vllm-flash-attn inside vLLM. Skipping flag detection and relying on parent build.")
macro(check_found NAME VAR)
if (NOT ${VAR})
message(FATAL_ERROR "${NAME} must have been found by parent.")
endif ()
endmacro()
check_found("Torch" TORCH_FOUND)
set(VLLM_FA_GPU_FLAGS ${VLLM_GPU_FLAGS})
set(VLLM_FA_GPU_ARCHES ${VLLM_GPU_ARCHES})
# Allow direct override of GPU architectures.
# These have to be in CMake syntax (75-real, 89-virtual, etc).
if (DEFINED ENV{VLLM_FA_CMAKE_GPU_ARCHES})
message(STATUS "Overriding GPU architectures to $ENV{VLLM_FA_CMAKE_GPU_ARCHES}")
set(VLLM_FA_GPU_ARCHES $ENV{VLLM_FA_CMAKE_GPU_ARCHES})
# Generally, we want to build with a subset of the parent arches.
foreach (VLLM_FA_GPU_ARCH IN LISTS VLLM_FA_GPU_ARCHES)
if (NOT VLLM_FA_GPU_ARCH IN_LIST VLLM_GPU_ARCHES)
message(WARNING "Using GPU architecture ${VLLM_FA_GPU_ARCH}, "
"which is not included in the parent list.")
endif ()
endforeach ()
endif ()
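# As a hypothetical illustration (not part of this build), a parent invocation could
# export e.g. VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real" before configuring, restricting
# this sub-build to Ampere and Hopper binary (non-PTX) code only.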
else ()
message(STATUS "Standalone vllm-flash-attn build.")
#
# Update cmake's `CMAKE_PREFIX_PATH` with torch location.
#
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
message(DEBUG "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}")
#
# Import torch cmake configuration.
# Torch also imports CUDA (and partially HIP) languages with some customizations,
# so there is no need to do this explicitly with check_language/enable_language,
# etc.
#
find_package(Torch REQUIRED)
#
# Set up GPU language and check the torch version and warn if it isn't
# what is expected.
#
if (NOT HIP_FOUND AND CUDA_FOUND)
set(VLLM_GPU_LANG "CUDA")
# Check CUDA is at least 11.6
if (CUDA_VERSION VERSION_LESS 11.6)
message(FATAL_ERROR "CUDA version 11.6 or greater is required.")
endif ()
if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA})
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} "
"expected for CUDA build, saw ${Torch_VERSION} instead.")
endif ()
elseif (HIP_FOUND)
message(FATAL_ERROR "ROCm build is not currently supported for vllm-flash-attn.")
set(VLLM_GPU_LANG "HIP")
# Importing torch recognizes and sets up some HIP/ROCm configuration but does
# not let cmake recognize .hip files. In order to get cmake to understand the
# .hip extension automatically, HIP must be enabled explicitly.
enable_language(HIP)
# ROCm 5.X and 6.X
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
"expected for ROCm build, saw ${Torch_VERSION} instead.")
endif ()
else ()
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif ()
#
# Override the GPU architectures detected by cmake/torch and filter them by
# the supported versions for the current language.
# The final set of arches is stored in `VLLM_FA_GPU_ARCHES`.
#
override_gpu_arches(VLLM_FA_GPU_ARCHES
${VLLM_GPU_LANG}
"${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
#
# Query torch for additional GPU compilation flags for the given
# `VLLM_GPU_LANG`.
# The final set of flags is stored in `VLLM_FA_GPU_FLAGS`.
#
get_torch_gpu_compiler_flags(VLLM_FA_GPU_FLAGS ${VLLM_GPU_LANG})
#
# Set nvcc parallelism.
#
if (NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_FA_GPU_FLAGS "--threads=${NVCC_THREADS}")
endif ()
endif ()
# Other flags
list(APPEND VLLM_FA_GPU_FLAGS --expt-relaxed-constexpr --expt-extended-lambda --use_fast_math)
# Replace instead of appending; nvcc doesn't like duplicate -O flags.
string(REPLACE "-O2" "-O3" CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${CMAKE_CUDA_FLAGS_RELWITHDEBINFO}")
#
# _C extension
#
file(GLOB FLASH_ATTN_GEN_SRCS "csrc/flash_attn/src/flash_fwd_*.cu")
message(DEBUG "FLASH_ATTN_GEN_SRCS: ${FLASH_ATTN_GEN_SRCS}")
define_gpu_extension_target(
vllm_flash_attn_c
DESTINATION vllm_flash_attn
LANGUAGE ${VLLM_GPU_LANG}
SOURCES csrc/flash_attn/flash_api.cpp ${FLASH_ATTN_GEN_SRCS}
COMPILE_FLAGS ${VLLM_FA_GPU_FLAGS}
ARCHITECTURES ${VLLM_FA_GPU_ARCHES}
USE_SABI 3
WITH_SOABI
)
target_include_directories(vllm_flash_attn_c PRIVATE
csrc/flash_attn
csrc/flash_attn/src
csrc/cutlass/include)
# custom definitions
target_compile_definitions(vllm_flash_attn_c PRIVATE
# FLASHATTENTION_DISABLE_BACKWARD
FLASHATTENTION_DISABLE_DROPOUT
# FLASHATTENTION_DISABLE_ALIBI
# FLASHATTENTION_DISABLE_SOFTCAP
FLASHATTENTION_DISABLE_UNEVEN_K
# FLASHATTENTION_DISABLE_LOCAL
)
# Check for old generator
find_file(OLD_GENERATOR_FILE "ATen/CUDAGeneratorImpl.h" ${TORCH_INCLUDE_DIRS} NO_DEFAULT_PATH)
if (OLD_GENERATOR_FILE)
target_compile_definitions(vllm_flash_attn_c PRIVATE -DOLD_GENERATOR_PATH)
endif ()
# This file is taken from github.com/vllm-project/vllm/cmake/utils.cmake
# It contains utility functions for building PyTorch extensions with GPU support.
# The only modification is finding the `Python` package with a specific version.
#
# THIS MACRO WAS MODIFIED
# Attempts to find the Python package and verifies that it is one of the `SUPPORTED_VERSIONS`.
# To customize which python gets found, set the `Python_EXECUTABLE` variable.
macro (find_python_constrained_versions SUPPORTED_VERSIONS)
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
if (NOT Python_FOUND)
if (Python_EXECUTABLE)
message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.")
else()
message(FATAL_ERROR "Unable to find python.")
endif ()
endif()
set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
message(FATAL_ERROR
"Python version (${_VER}) is not one of the supported versions: "
"${_SUPPORTED_VERSIONS_LIST}.")
endif()
endmacro()
#
# Run `EXPR` in python. The standard output of python is stored in `OUT` and
# has trailing whitespace stripped. If an error is encountered when running
# python, a fatal message `ERR_MSG` is issued.
#
function (run_python OUT EXPR ERR_MSG)
execute_process(
COMMAND
"${Python_EXECUTABLE}" "-c" "${EXPR}"
OUTPUT_VARIABLE PYTHON_OUT
RESULT_VARIABLE PYTHON_ERROR_CODE
ERROR_VARIABLE PYTHON_STDERR
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT PYTHON_ERROR_CODE EQUAL 0)
message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
endif()
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
endfunction()
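# Illustrative example (assumed usage, not invoked anywhere in this file): query the
# installed torch version into TORCH_VER_OUT:
#   run_python(TORCH_VER_OUT "import torch; print(torch.__version__)" "Failed to query torch version")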
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
macro (append_cmake_prefix_path PKG EXPR)
run_python(_PREFIX_PATH
"import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
endmacro()
#
# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
# of CUDA source files. The names of the corresponding "hipified" sources are
# stored in `OUT_SRCS`.
#
function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
#
# Split into C++ and non-C++ (i.e. CUDA) sources.
#
set(SRCS ${ORIG_SRCS})
set(CXX_SRCS ${ORIG_SRCS})
list(FILTER SRCS EXCLUDE REGEX "\.(cc)|(cpp)$")
list(FILTER CXX_SRCS INCLUDE REGEX "\.(cc)|(cpp)$")
#
# Generate ROCm/HIP source file names from CUDA file names.
# Since HIP files are generated code, they will appear in the build area
# `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
#
set(HIP_SRCS)
foreach (SRC ${SRCS})
string(REGEX REPLACE "\.cu$" "\.hip" SRC ${SRC})
string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
endforeach()
set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
add_custom_target(
hipify${NAME}
COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
BYPRODUCTS ${HIP_SRCS}
COMMENT "Running hipify on ${NAME} extension source files.")
# Swap out original extension sources with hipified sources.
list(APPEND HIP_SRCS ${CXX_SRCS})
set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
endfunction()
#
# Get additional GPU compiler flags from torch.
#
function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
if (${GPU_LANG} STREQUAL "CUDA")
#
# Get common NVCC flags from torch.
#
run_python(GPU_FLAGS
"from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
"Failed to determine torch nvcc compiler flags")
if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
list(APPEND GPU_FLAGS "-DENABLE_FP8")
endif()
if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
list(REMOVE_ITEM GPU_FLAGS
"-D__CUDA_NO_HALF_OPERATORS__"
"-D__CUDA_NO_HALF_CONVERSIONS__"
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
"-D__CUDA_NO_HALF2_OPERATORS__")
endif()
elseif(${GPU_LANG} STREQUAL "HIP")
#
# Get common HIP/HIPCC flags from torch.
#
run_python(GPU_FLAGS
"import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
"Failed to determine torch hipcc compiler flags")
list(APPEND GPU_FLAGS
"-DUSE_ROCM"
"-DENABLE_FP8"
"-U__HIP_NO_HALF_CONVERSIONS__"
"-U__HIP_NO_HALF_OPERATORS__"
"-fno-gpu-rdc")
endif()
set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
endfunction()
# Macro for converting a `gencode` version number to a cmake version number.
macro(string_to_ver OUT_VER IN_STR)
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
endmacro()
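# For example, string_to_ver(_VER "80") sets _VER to "8.0", and "90" becomes "9.0",
# matching the dotted form of the supported-arch lists (e.g. "8.0;8.6;8.9;9.0").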
#
# Override the GPU architectures detected by cmake/torch and filter them by
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
# `GPU_ARCHES`.
#
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
#
macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")
if (${GPU_LANG} STREQUAL "HIP")
#
# `GPU_ARCHES` controls the `--offload-arch` flags.
#
# If PYTORCH_ROCM_ARCH env variable exists, then we take it as a list,
# if not, then we use CMAKE_HIP_ARCHITECTURES which was generated by calling
# "rocm_agent_enumerator" in "enable_language(HIP)"
# (in file Modules/CMakeDetermineHIPCompiler.cmake)
#
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
set(HIP_ARCHITECTURES $ENV{PYTORCH_ROCM_ARCH})
else()
set(HIP_ARCHITECTURES ${CMAKE_HIP_ARCHITECTURES})
endif()
#
# Find the intersection of the supported + detected architectures to
# set the module architecture flags.
#
set(${GPU_ARCHES})
foreach (_ARCH ${HIP_ARCHITECTURES})
if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
list(APPEND ${GPU_ARCHES} ${_ARCH})
endif()
endforeach()
if(NOT ${GPU_ARCHES})
message(FATAL_ERROR
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
endif()
elseif(${GPU_LANG} STREQUAL "CUDA")
#
# Setup/process CUDA arch flags.
#
# The torch cmake setup hardcodes the detected architecture flags in
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
# can't be modified on a per-target basis.
# So, all the `-gencode` flags need to be extracted and removed from
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
# Since it's not possible to use `target_compile_options` for adding target
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
# must be used instead. This requires repackaging the architecture flags
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
#
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
# to one of the other nvcc options to specify architectures.
#
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
# detected architectures.
#
message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
# If this error is triggered, it might mean that torch has changed how it sets
# up nvcc architecture code generation flags.
if (NOT _CUDA_ARCH_FLAGS)
message(FATAL_ERROR
"Could not find any architecture related code generation flags in "
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
endif()
message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
# Initialize the architecture lists to empty.
set(${GPU_ARCHES})
# Process each `gencode` flag.
foreach(_ARCH ${_CUDA_ARCH_FLAGS})
# For each flag, extract the version number and whether it refers to PTX
# or native code.
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
# for that match.
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
if (_COMPUTE)
set(_COMPUTE ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
if (_SM)
set(_SM ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
if (_CODE)
set(_CODE ${CMAKE_MATCH_1})
endif()
# Make sure the virtual architecture can be matched.
if (NOT _COMPUTE)
message(FATAL_ERROR
"Could not determine virtual architecture from: ${_ARCH}.")
endif()
# One of sm_ or compute_ must exist.
if ((NOT _SM) AND (NOT _CODE))
message(FATAL_ERROR
"Could not determine a codegen architecture from: ${_ARCH}.")
endif()
if (_SM)
# The -real suffix lets CMake generate only ELF (SASS) code for the kernels.
# We want this; otherwise the PTX added by default would increase binary size.
set(_VIRT "-real")
set(_CODE_ARCH ${_SM})
else()
# The -virtual suffix lets CMake generate PTX code for the kernels.
set(_VIRT "-virtual")
set(_CODE_ARCH ${_CODE})
endif()
# Check if the current version is in the supported arch list.
string_to_ver(_CODE_VER ${_CODE_ARCH})
if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
message(STATUS "discarding unsupported CUDA arch ${_CODE_VER}.")
continue()
endif()
# Add it to the arch list.
list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
endforeach()
endif()
message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
endmacro()
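# For example, a torch-provided "-gencode arch=compute_80,code=sm_80" contributes
# "80-real" to the output list, while "-gencode arch=compute_90,code=compute_90"
# contributes "90-virtual" (PTX only).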
#
# Define a target named `GPU_MOD_NAME` for a single extension. The
# arguments are:
#
# DESTINATION <dest> - Module destination directory.
# LANGUAGE <lang> - The GPU language for this module, e.g. CUDA, HIP,
# etc.
# SOURCES <sources> - List of source files relative to CMakeLists.txt
# directory.
#
# Optional arguments:
#
# ARCHITECTURES <arches> - A list of target GPU architectures in cmake
# format.
# Refer to the `CMAKE_CUDA_ARCHITECTURES` and
# `CMAKE_HIP_ARCHITECTURES` documentation for more info.
# ARCHITECTURES will use cmake's defaults if
# not provided.
# COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hipcc.
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
# LIBRARIES <libraries> - Extra link libraries.
# WITH_SOABI - Generate library with python SOABI suffix name.
# USE_SABI <version> - Use the Python stable ABI (limited API) of the given <version>.
#
# Note: optimization level/debug info is set via cmake build type.
#
function (define_gpu_extension_target GPU_MOD_NAME)
cmake_parse_arguments(PARSE_ARGV 1
GPU
"WITH_SOABI"
"DESTINATION;LANGUAGE;USE_SABI"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
# Add hipify preprocessing step when building with HIP/ROCm.
if (GPU_LANGUAGE STREQUAL "HIP")
hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
endif()
if (GPU_WITH_SOABI)
set(GPU_WITH_SOABI WITH_SOABI)
else()
set(GPU_WITH_SOABI)
endif()
if (GPU_USE_SABI)
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
else()
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
endif()
if (GPU_LANGUAGE STREQUAL "HIP")
# Make this target dependent on the hipify preprocessor step.
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
endif()
if (GPU_ARCHITECTURES)
set_target_properties(${GPU_MOD_NAME} PROPERTIES
${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
endif()
set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
target_compile_options(${GPU_MOD_NAME} PRIVATE
$<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
target_compile_definitions(${GPU_MOD_NAME} PRIVATE
"-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
${GPU_INCLUDE_DIRECTORIES})
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
# Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
# dependencies that are not necessary and may not be installed.
if (GPU_LANGUAGE STREQUAL "CUDA")
if ("${CUDA_CUDA_LIB}" STREQUAL "")
set(CUDA_CUDA_LIB "${CUDA_CUDA_LIBRARY}")
endif()
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB}
${CUDA_LIBRARIES})
else()
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
endif()
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
endfunction()
@@ -3,7 +3,8 @@
******************************************************************************/
// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers.
#include <torch/python.h>
#include "registration.h"
#include <torch/library.h>
#include <torch/nn/functional.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
@@ -241,7 +242,7 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(Flash_fwd_params &params,
return std::make_tuple(softmax_lse_accum, out_accum);
}
void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
void set_params_alibi(Flash_fwd_params &params, const c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
#ifdef FLASHATTENTION_DISABLE_ALIBI
TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
params.alibi_slopes_ptr = nullptr;
@@ -264,14 +265,14 @@ std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
const c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const double p_dropout,
const double softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
int64_t window_size_left,
int64_t window_size_right,
const double softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@@ -452,21 +453,21 @@ std::vector<at::Tensor>
mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
const c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int64_t max_seqlen_q,
const int64_t max_seqlen_k,
const double p_dropout,
const double softmax_scale,
const bool zero_tensors,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
int64_t window_size_left,
int64_t window_size_right,
const double softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@@ -708,23 +709,23 @@ std::vector<at::Tensor>
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const float softmax_scale,
const c10::optional<at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
const c10::optional<at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
const c10::optional<at::Tensor> &seqlens_k_, // batch_size
const c10::optional<at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
const c10::optional<at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
const c10::optional<at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
const c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
const c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
const double softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const float softcap,
int64_t window_size_left,
int64_t window_size_right,
const double softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits
) {
int64_t num_splits
) {
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -983,11 +984,26 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
return {out, softmax_lse};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "FlashAttention";
m.def("fwd", &mha_fwd, "Forward pass");
m.def("varlen_fwd", &mha_varlen_fwd, "Forward pass (variable length)");
// m.def("bwd", &mha_bwd, "Backward pass");
// m.def("varlen_bwd", &mha_varlen_bwd, "Backward pass (variable length)");
m.def("fwd_kvcache", &mha_fwd_kvcache, "Forward pass, with KV-cache");
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor? alibi_slopes, "
"float p_dropout, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
"float softcap, bool return_softmax, Generator? gen)"
"-> Tensor[]");
ops.impl("fwd", torch::kCUDA, &mha_fwd);
ops.def("varlen_fwd(Tensor! q, Tensor k, Tensor v, Tensor!? out, Tensor cu_seqlens_q, "
"Tensor cu_seqlens_k, Tensor? seqused_k, Tensor? block_table, Tensor? alibi_slopes, "
"int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, "
"bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, "
"Generator? gen) -> Tensor[]");
ops.impl("varlen_fwd", torch::kCUDA, &mha_varlen_fwd);
ops.def("fwd_kvcache(Tensor! q, Tensor kcache, Tensor vcache, Tensor? k, Tensor? v, Tensor? seqlens_k, "
"Tensor? rotary_cos, Tensor? rotary_sin, Tensor? cache_batch_idx, Tensor? block_table, Tensor? alibi_slopes, "
"Tensor!? out, float softmax_scale, bool is_causal, int window_size_left, int window_size_right, "
"float softcap, bool is_rotary_interleaved, int num_splits) -> Tensor[]");
ops.impl("fwd_kvcache", torch::kCUDA, &mha_fwd_kvcache);
}
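// Once the extension is loaded from Python, these ops are exposed under
// torch.ops.<TORCH_EXTENSION_NAME> (torch.ops.vllm_flash_attn_c.fwd(...) with the
// module name used by this repo's CMake). In the schemas above, "Tensor?" marks an
// optional tensor and "Tensor!" marks a tensor the op may modify in place.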
REGISTER_EXTENSION(TORCH_EXTENSION_NAME);
#pragma once
#include <Python.h>
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
[build-system]
# Should be mirrored in requirements-build.txt
requires = [
"cmake>=3.26",
"ninja",
"packaging",
"setuptools >= 49.4.0",
"torch == 2.4.0",
"wheel",
"jinja2",
]
build-backend = "setuptools.build_meta"
This diff is collapsed.
from typing import Optional, List, Tuple
import torch
from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache
@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
def flash_attn_varlen_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# custom op does not support tuple input
real_window_size: Tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return _flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
window_size=real_window_size,
softcap=softcap,
alibi_slopes=alibi_slopes,
block_table=block_table,
)
@flash_attn_varlen_func.register_fake # type: ignore
def _(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Optional[List[int]] = None,
softcap: float = 0.0,
alibi_slopes: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(q)
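# Minimal usage sketch (hypothetical helper and shapes, for illustration only; it is not
# called anywhere): two sequences of lengths 3 and 5 packed into one varlen batch.
def _example_flash_attn_varlen_usage() -> torch.Tensor:
    total_tokens, nheads, headdim = 8, 4, 64
    q = torch.randn(total_tokens, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # cu_seqlens holds cumulative sequence lengths: [0, 3, 8] -> lengths 3 and 5.
    cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=5,
        max_seqlen_k=5,
        causal=True,
    )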
@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[])
def flash_attn_with_kvcache(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return _flash_attn_with_kvcache(
decode_query,
key_cache,
value_cache,
cache_seqlens=cache_seqlens,
block_table=block_table,
softmax_scale=softmax_scale,
causal=causal,
alibi_slopes=alibi_slopes,
softcap=softcap,
)
@flash_attn_with_kvcache.register_fake # type: ignore
def _(
decode_query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cache_seqlens: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
alibi_slopes: Optional[torch.Tensor] = None,
softcap: float = 0.0,
) -> torch.Tensor:
return torch.empty_like(decode_query)
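# Minimal decode-time usage sketch (hypothetical helper and shapes, for illustration only;
# it is not called anywhere): one new query token per sequence against a paged KV cache.
def _example_flash_attn_with_kvcache_usage() -> torch.Tensor:
    batch, nheads, headdim = 2, 4, 64
    num_blocks, block_size = 16, 16
    q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)
    key_cache = torch.randn(
        num_blocks, block_size, nheads, headdim, device="cuda", dtype=torch.float16
    )
    value_cache = torch.randn_like(key_cache)
    # Number of valid cached tokens per sequence, and the per-sequence block mapping.
    cache_seqlens = torch.tensor([7, 12], device="cuda", dtype=torch.int32)
    block_table = torch.arange(num_blocks, device="cuda", dtype=torch.int32).reshape(batch, -1)
    return flash_attn_with_kvcache(
        q,
        key_cache,
        value_cache,
        cache_seqlens=cache_seqlens,
        block_table=block_table,
        causal=True,
    )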
# Copyright (c) 2023, Tri Dao.
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import RotaryEmbedding, apply_rotary_emb_func, apply_rotary_emb_qkv_
from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import (
apply_rotary_pos_emb as apply_rotary_pos_emb_neox,
)
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding
# NeoX-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
device = "cuda"
dtype = torch.float16
rtol, atol = (1e-3, 5e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen_total = 2048
seqlen = seqlen_total - seqlen_offset
nheads = 16
headdim = 128
rotary_dim = int(headdim * rotary_emb_fraction)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
)
qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace
rotary = RotaryEmbedding(rotary_dim, device=device)
rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
# It doesn't matter what tensor we pass in; rotary_neox only uses the device of the tensor
cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
q_pt = (
rearrange(qkv[:, :, 0, :, :rotary_dim], "b s h d -> b h s d")
.detach()
.clone()
.requires_grad_(True)
)
k_pt = (
rearrange(qkv[:, :, 1, :, :rotary_dim], "b s h d -> b h s d")
.detach()
.clone()
.requires_grad_(True)
)
q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
out = rotary(qkv, seqlen_offset=seqlen_offset)
assert torch.allclose(
rotary._cos_cached, cos_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
)
assert torch.allclose(
rotary._sin_cached, sin_neox[..., : rotary_dim // 2].to(dtype=dtype), rtol=rtol, atol=atol
)
assert torch.allclose(
rearrange(q_neox, "b h s d -> b s h d"), out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol
)
assert torch.allclose(
rearrange(k_neox, "b h s d -> b s h d"), out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol
)
assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
g = torch.randn_like(out)
g_og = g.clone().detach() # Our implementation modifies g inplace
out.backward(g)
q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], "b s h d -> b h s d"))
k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], "b s h d -> b h s d"))
assert torch.allclose(
rearrange(q_pt.grad, "b h s d -> b s h d"),
qkv.grad[:, :, 0, :, :rotary_dim],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
rearrange(k_pt.grad, "b h s d -> b s h d"),
qkv.grad[:, :, 1, :, :rotary_dim],
rtol=rtol,
atol=atol,
)
assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])
# GPT-J-style rotary embedding
@pytest.mark.parametrize("seqlen_offset", [0, 711])
@pytest.mark.parametrize("rotary_emb_fraction", [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
device = "cuda"
dtype = torch.float16
rtol, atol = (1e-3, 5e-3)
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen_total = 2048
seqlen = seqlen_total - seqlen_offset
nheads = 16
headdim = 128
rotary_dim = int(headdim * rotary_emb_fraction)
qkv = torch.randn(
batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True
)
qkv_og = qkv.clone().detach() # Our implementation modifies qkv inplace
rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
sincos_gptj = tuple(x.to(dtype=dtype) for x in sincos_gptj)
q_pt = qkv[:, :, 0, :, :rotary_dim].detach().clone().requires_grad_(True)
k_pt = qkv[:, :, 1, :, :rotary_dim].detach().clone().requires_grad_(True)
q_gptj = apply_rotary_pos_emb_gptj(q_pt, sincos_gptj, offset=seqlen_offset)
k_gptj = apply_rotary_pos_emb_gptj(k_pt, sincos_gptj, offset=seqlen_offset)
out = rotary(qkv, seqlen_offset=seqlen_offset)
assert torch.allclose(rotary._cos_cached, sincos_gptj[1], rtol=rtol, atol=atol)
assert torch.allclose(rotary._sin_cached, sincos_gptj[0], rtol=rtol, atol=atol)
assert torch.allclose(q_gptj, out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.allclose(k_gptj, out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])
g = torch.randn_like(out)
g_og = g.clone().detach() # Our implementation modifies g inplace
out.backward(g)
q_gptj.backward(g_og[:, :, 0, :, :rotary_dim])
k_gptj.backward(g_og[:, :, 1, :, :rotary_dim])
assert torch.allclose(q_pt.grad, qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.allclose(k_pt.grad, qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])
import math
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
@pytest.mark.parametrize("return_z_loss", [False, True])
# @pytest.mark.parametrize("lse_square_scale", [1e-2])
@pytest.mark.parametrize("logit_scale", [1.0, 0.7])
# @pytest.mark.parametrize("logit_scale", [1.0])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50257, 128 * 1024]) # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [12])
def test_cross_entropy_loss(
vocab_size, smoothing, logit_scale, lse_square_scale, return_z_loss, inplace_backward, dtype
):
device = "cuda"
rtol, atol = (1e-5, 1e-6) if dtype == torch.float32 else (1e-3, 1e-4)
# set seed
torch.random.manual_seed(0)
batch_size = 1 if dtype == torch.float32 else 4 # Otherwise OOM
seqlen = 4096 if lse_square_scale == 0.0 and logit_scale == 1.0 else 1024 # Otherwise OOM
x_pt = torch.randn(
batch_size * seqlen, vocab_size, device=device, dtype=dtype, requires_grad=True
)
x = x_pt.detach().clone().requires_grad_()
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
if batch_size * seqlen > 10:
y[torch.randperm(batch_size * seqlen)[:10]] = -100
model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing)
model = CrossEntropyLoss(
label_smoothing=smoothing,
logit_scale=logit_scale,
lse_square_scale=lse_square_scale,
return_z_loss=return_z_loss,
inplace_backward=inplace_backward,
)
if return_z_loss:
out, out_z_loss = model(x, y)
else:
out = model(x, y)
x_pt_scaled = (x_pt.float() * logit_scale) if logit_scale != 1.0 else x_pt.float()
out_pt = model_pt(x_pt_scaled, y)
if lse_square_scale > 0.0:
lse_pt = torch.logsumexp(x_pt_scaled, dim=-1)
z_loss_pt = lse_square_scale * (lse_pt[y != -100] ** 2).mean()
if return_z_loss:
assert torch.allclose(out_z_loss, z_loss_pt, rtol=rtol, atol=atol)
out_pt += z_loss_pt
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
g = torch.randn_like(out)
out_pt.backward(g)
out.backward(g)
assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=atol)
# Run test with:
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/losses/test_cross_entropy_parallel.py
import math
import pytest
import torch
from apex.transformer import parallel_state, tensor_parallel
from flash_attn.losses.cross_entropy import CrossEntropyLoss
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize(
"dtype", [torch.float16, torch.float32] + ([torch.bfloat16] if is_sm8x else [])
)
# @pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("inplace_backward", [False, True])
# @pytest.mark.parametrize("inplace_backward", [False])
@pytest.mark.parametrize("lse_square_scale", [0.0, 1e-2])
# @pytest.mark.parametrize("lse_square_scale", [0.0])
@pytest.mark.parametrize("logit_scale", [0.7])
# @pytest.mark.parametrize("logit_scale", [1.0])
@pytest.mark.parametrize("smoothing", [0.0, 0.9])
# @pytest.mark.parametrize("smoothing", [0.0])
@pytest.mark.parametrize("vocab_size", [50264, 256 * 1024]) # test vocab larger than 64k for split
# @pytest.mark.parametrize("vocab_size", [50264]) # test vocab larger than 64k for split
# @pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("world_size", [2])
def test_cross_entropy_loss_parallel(
vocab_size, world_size, smoothing, logit_scale, lse_square_scale, inplace_backward, dtype
):
assert vocab_size % world_size == 0
rtol, atol = (
(1e-5, 2e-5)
if dtype == torch.float32
else ((1e-3, 1e-4) if dtype == torch.float16 else (1e-2, 3e-3))
)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
partition_vocab_size = vocab_size // world_size
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 128
x_pt = (
torch.randn(batch_size * seqlen, vocab_size, device=device, dtype=dtype) * 10
).requires_grad_()
x = (
tensor_parallel.scatter_to_tensor_model_parallel_region(x_pt)
.detach()
.clone()
.requires_grad_()
)
y = torch.randint(0, vocab_size, (batch_size * seqlen,), dtype=torch.long, device=device)
y[torch.randperm(batch_size * seqlen)[:10]] = -100
model_pt = torch.nn.CrossEntropyLoss(label_smoothing=smoothing, reduction="none")
model = CrossEntropyLoss(
label_smoothing=smoothing,
logit_scale=logit_scale,
reduction="none",
lse_square_scale=lse_square_scale,
inplace_backward=inplace_backward,
process_group=parallel_state.get_tensor_model_parallel_group(),
)
out = model(x, y)
out_pt = model_pt(x_pt.float() * logit_scale, y)
if lse_square_scale > 0.0:
lse_pt = torch.logsumexp(x_pt.float() * logit_scale, dim=-1)
out_pt += lse_square_scale * lse_pt.square()
out_pt.masked_fill_(y == -100, 0.0)
assert torch.allclose(out, out_pt, rtol=1e-5, atol=1e-6)
g = torch.randn_like(out)
out_pt.backward(g)
out.backward(g)
assert torch.allclose(
x.grad,
x_pt.grad[:, (rank * partition_vocab_size) : (rank + 1) * partition_vocab_size],
rtol=rtol,
atol=atol,
)
parallel_state.destroy_model_parallel()
# Copyright (c) 2023, Tri Dao.
import os
import time
from pathlib import Path
import torch
import pytest
from einops import rearrange
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from flash_attn.models.gpt import (
GPTLMHeadModel,
combine_state_dicts_tp,
shard_state_dict_tp,
)
from flash_attn.models.baichuan import (
remap_state_dict_hf_baichuan,
baichuan_config_to_gpt2_config,
)
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
@pytest.mark.parametrize(
"model_name",
[
"baichuan-inc/Baichuan-7B",
"baichuan-inc/Baichuan-13B-Base",
"baichuan-inc/Baichuan2-7B-Base",
"baichuan-inc/Baichuan2-13B-Base",
],
)
def test_baichuan_state_dict(model_name):
config = baichuan_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
pretrained_state_dict = remap_state_dict_hf_baichuan(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize(
"model_name",
[
"baichuan-inc/Baichuan-7B",
"baichuan-inc/Baichuan-13B-Base",
"baichuan-inc/Baichuan2-7B-Base",
"baichuan-inc/Baichuan2-13B-Base",
],
)
def test_baichuan_optimized(model_name):
"""Check that our implementation of Baichuan (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = baichuan_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
pretrained_state_dict = remap_state_dict_hf_baichuan(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.load_state_dict(pretrained_state_dict)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
# Without device_map, the model is loaded on the CPU, which is very slow
# Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map={"": device},
trust_remote_code=True,
)
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.model(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel_forward"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize(
"model_name",
[
"baichuan-inc/Baichuan-7B",
"baichuan-inc/Baichuan-13B-Base",
"baichuan-inc/Baichuan2-7B-Base",
"baichuan-inc/Baichuan2-13B-Base",
],
)
def test_baichuan_parallel_forward(model_name, world_size):
"""Check that our implementation of Baichuan (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
from apex.transformer import parallel_state
dtype = torch.float16
config = baichuan_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
pretrained_state_dict = remap_state_dict_hf_baichuan(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
out, _ = all_gather_raw(out, process_group=process_group)
out = rearrange(out, "(b s) d -> b s d", b=batch_size)
logits = model(input_ids).logits
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
logits, _ = all_gather_raw(logits, process_group)
logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
del model
parallel_state.destroy_model_parallel()
if rank == 0:
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
)
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
logits_hf = model_hf(input_ids).logits.to(device=device)
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize(
"model_name", ["baichuan-inc/Baichuan-7B", "baichuan-inc/Baichuan-13B-Base"]
)
def test_baichuan_generation(model_name):
dtype = torch.float16
device = "cuda"
config = baichuan_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
eos_token_id = tokenizer.eos_token_id
torch.manual_seed(0)
batch_size = 1
seqlen = 2048
max_length = 2048 + 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
# Need auto here since the 13B fp32 model doesn't fit in memory on an A100 40GB
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
del model_ref
pretrained_state_dict = remap_state_dict_hf_baichuan(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.load_state_dict(pretrained_state_dict)
model.eval()
model(input_ids) # Warm up
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
del model
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
assert torch.equal(logits_cg, logits)
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "baichuan_parallel_generation"
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("model_name", ["baichuan-inc/Baichuan-7B"])
def test_baichuan_parallel_generation(model_name, world_size):
"""Check that our implementation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
from apex.transformer import parallel_state
dtype = torch.float16
config = baichuan_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused GatedMLP yet
config.fused_dropout_add_ln = False
config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8 * world_size
config.sequence_parallel = False # Need to set this to False for generation
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
pretrained_state_dict = remap_state_dict_hf_baichuan(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
print("Without CUDA graph")
out = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
cg=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
del model
parallel_state.destroy_model_parallel()
if rank == 0:
# Without device_map, the model is loaded on the CPU, which is very slow
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
with torch.inference_mode():
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.inference_mode():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
assert torch.equal(logits_cg, logits)
import re
from collections import OrderedDict
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
from flash_attn.models.bert import (
BertForPreTraining,
BertModel,
inv_remap_state_dict,
remap_state_dict,
)
from flash_attn.utils.pretrained import state_dict_from_pretrained
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
config = BertConfig.from_pretrained(model_name)
pretrained_state_dict = remap_state_dict(state_dict_from_pretrained(model_name), config)
model = BertForPreTraining(config)
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
def get_hf_models(model_name, config, dtype):
pretrained_state_dict = state_dict_from_pretrained(model_name)
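    # Older HF BERT checkpoints store LayerNorm parameters as "gamma"/"beta", while recent
    # versions of transformers expect "weight"/"bias", so remap those keys first.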
def key_mapping_ln_gamma_beta(key):
key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
return key
pretrained_state_dict = OrderedDict(
(key_mapping_ln_gamma_beta(k), v) for k, v in pretrained_state_dict.items()
)
model_hf = BertForPreTrainingHF(config)
# Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
# position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
model_hf.load_state_dict(pretrained_state_dict, strict=False)
model_hf.cuda().to(dtype=dtype)
return model_hf
@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
"""Check that our implementation of BERT (without any optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, dtype)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
print(f"Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}")
print(f"Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}")
assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (
sequence_output_hf - sequence_output_ref
).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (
pooled_output_hf - pooled_output_ref
).abs().max().item()
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
"""Check that our implementation of BERT (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation of fused_mlp assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
# If you just want "gelu", disable fused_mlp.
config.hidden_act = "gelu_new"
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, dtype)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.bert(input_ids, attention_mask=attention_mask)
sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
# Need to zero out the padded tokens in the sequence before comparison.
sequence_output_hf[~attention_mask, :] = 0.0
out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
sequence_output_ref[~attention_mask, :] = 0.0
print(
f"BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}"
)
print(
f"BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}"
)
print(
f"HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}"
)
print(
f"HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}"
)
assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
sequence_output_hf - sequence_output_ref
).abs().max().item()
assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
pooled_output_hf - pooled_output_ref
).abs().max().item()
out = model(input_ids, attention_mask=attention_mask)
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
# Need to zero out the padded tokens in the sequence before comparison.
prediction_scores = prediction_scores.clone()
prediction_scores[~attention_mask, :] = 0.0
out_hf = model_hf(input_ids, attention_mask=attention_mask)
prediction_scores_hf, seq_relationship_scores_hf = (
out_hf.prediction_logits,
out_hf.seq_relationship_logits,
)
prediction_scores_hf[~attention_mask, :] = 0.0
out_ref = model_ref(input_ids, attention_mask=attention_mask)
prediction_scores_ref, seq_relationship_scores_ref = (
out_ref.prediction_logits,
out_ref.seq_relationship_logits,
)
prediction_scores_ref[~attention_mask, :] = 0.0
print(
f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
)
print(
f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
)
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
prediction_scores_hf - prediction_scores_ref
).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
seq_relationship_scores_hf - seq_relationship_scores_ref
).abs().max().item()
@pytest.mark.parametrize("last_layer_subset", [False, True])
# @pytest.mark.parametrize('last_layer_subset', [True])
@pytest.mark.parametrize("has_key_padding_mask", [True, False])
# @pytest.mark.parametrize('has_key_padding_mask', [True])
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
"""Check that our implementation of BERT (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
config = BertConfig.from_pretrained(model_name)
# Our implementation of fused_mlp assumes the activation is
# nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new", "gelu_fast", or "gelu_pytorch_tanh".
# If you just want "gelu", disable fused_mlp.
config.hidden_act = "gelu_new"
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.dense_seq_output = True
config.last_layer_subset = last_layer_subset
config.use_xentropy = True
model = BertForPreTraining.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = get_hf_models(model_name, config, torch.float32)
model_hf = get_hf_models(model_name, config, dtype)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
if has_key_padding_mask:
attention_mask = torch.arange(max_seqlen, device="cuda")[None, :] < seqlens[:, None]
else:
attention_mask = None
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
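    # Build masked-LM labels: roughly 15% of positions keep a nonzero label, the rest are
    # zeroed out (ignored). With dense_seq_output=True our model only returns predictions at
    # the labeled positions, so masked_tokens_mask is used below to index the HF outputs.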
labels = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
if attention_mask is not None:
labels[~attention_mask] = 0
labels[(torch.rand(batch_size, max_seqlen, device="cuda") > 0.15)] = 0
masked_tokens_mask = labels.flatten() > 0
next_sequence_label = torch.randint(0, 2, (batch_size,), device="cuda")
out = model(
input_ids,
attention_mask=attention_mask,
labels=labels,
next_sentence_label=next_sequence_label,
)
prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
out_hf = model_hf(
input_ids,
attention_mask=attention_mask,
labels=labels,
next_sentence_label=next_sequence_label,
)
prediction_scores_hf, seq_relationship_scores_hf = (
out_hf.prediction_logits,
out_hf.seq_relationship_logits,
)
prediction_scores_hf = rearrange(prediction_scores_hf, "b s d -> (b s) d")[masked_tokens_mask]
out_ref = model_ref(
input_ids,
attention_mask=attention_mask,
labels=labels,
next_sentence_label=next_sequence_label,
)
prediction_scores_ref, seq_relationship_scores_ref = (
out_ref.prediction_logits,
out_ref.seq_relationship_logits,
)
prediction_scores_ref = rearrange(prediction_scores_ref, "b s d -> (b s) d")[masked_tokens_mask]
print(
f"prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}"
)
print(
f"prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}"
)
    print(
        f"HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}"
    )
    print(
        f"HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}"
    )
assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
prediction_scores_hf - prediction_scores_ref
).abs().max().item()
assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
seq_relationship_scores_hf - seq_relationship_scores_ref
).abs().max().item()
# The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
# assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()
@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_inv_remap_state_dict(model_name: str):
"""
Verify that we can convert a HF BERT model to flash_attn and back.
"""
state_dict = state_dict_from_pretrained(model_name)
config = BertConfig.from_pretrained(model_name)
flash_state_dict = remap_state_dict(state_dict, config)
recovered_state_dict = inv_remap_state_dict(flash_state_dict, config)
assert set(state_dict.keys()) == set(recovered_state_dict.keys())
for k in state_dict.keys():
assert state_dict[k].shape == recovered_state_dict[k].shape
torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)
import time
import pytest
import torch
from transformers import AutoTokenizer, GPTBigCodeConfig
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM
from flash_attn.models.bigcode import bigcode_config_to_gpt2_config, inv_remap_state_dict_hf_bigcode
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_bigcode
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_state_dict(model_name):
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_bigcode(
state_dict_from_pretrained(model_name), config
)
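    # device="meta" allocates no storage, so random init is skipped; only parameter shapes
    # are needed for this comparison.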
model = GPTLMHeadModel(config, device="meta")
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_optimized(model_name):
"""Check that our implementation of BigCode (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state
logits_ref = model_ref(input_ids).logits
del model_ref
model_hf = GPTBigCodeForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_bigcode_generation(model_name):
"""Check that our implementation of BigCode (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = bigcode_config_to_gpt2_config(GPTBigCodeConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = True
tokenizer = AutoTokenizer.from_pretrained(model_name)
eos_token_id = tokenizer.eos_token_id
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = GPTBigCodeForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = GPTBigCodeForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
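    # teacher_outputs forces generation to follow the HF-generated token sequence (teacher
    # forcing), so the per-step scores can be compared against logits_ref position by position.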
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
del model
hf_error = (logits_hf - logits_ref).abs().max().item()
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert (logits_cg - logits_ref).abs().max().item() < 2 * hf_error
@pytest.mark.parametrize("model_name", ["bigcode/starcoderbase-1b", "WizardLM/WizardCoder-1B-V1.0"])
def test_inv_remap_state_dict(model_name: str):
"""
Verify that we can convert a HF BigCode model to flash_attn and back.
"""
state_dict = state_dict_from_pretrained(model_name)
config = GPTBigCodeConfig.from_pretrained(model_name)
flash_state_dict = remap_state_dict_hf_bigcode(state_dict, config)
recovered_state_dict = inv_remap_state_dict_hf_bigcode(flash_state_dict, config)
assert set(state_dict.keys()) == set(recovered_state_dict.keys())
for k in state_dict.keys():
assert state_dict[k].shape == recovered_state_dict[k].shape
torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)
# Copyright (c) 2023, Tri Dao.
import time
import torch
import pytest
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.btlm import btlm_config_to_gpt2_config, remap_state_dict_hf_btlm
from flash_attn.utils.pretrained import state_dict_from_pretrained
from flash_attn.utils.generation import update_graph_cache
@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
def test_btlm_state_dict(model_name):
config = btlm_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert len(state_dict.keys()) == len(pretrained_state_dict.keys())
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
def test_btlm_optimized(model_name):
"""Check that our implementation of Btlm (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = btlm_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.fused_bias_fc = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.load_state_dict(pretrained_state_dict)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
# Without device_map, the model is loaded on the CPU, which is very slow
    # Need device_map="auto" so the fp32 reference can spill to CPU if it doesn't fit in GPU memory
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=dtype,
device_map={"": device},
trust_remote_code=True,
)
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
def test_btlm_generation(model_name):
dtype = torch.float16
device = "cuda"
config = btlm_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.fused_bias_fc = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
eos_token_id = tokenizer.eos_token_id
torch.manual_seed(0)
batch_size = 1
seqlen = 2048
max_length = 2048 + 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
    # Need device_map="auto" so the fp32 reference can spill to CPU if it doesn't fit in GPU memory
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1].to(device=device)
del model_ref
pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.load_state_dict(pretrained_state_dict)
model.eval()
model(input_ids) # Warm up
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
del model
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
assert torch.equal(logits_cg, logits)
@pytest.mark.parametrize("model_name", ["cerebras/btlm-3b-8k-base"])
def test_btlm_init(model_name):
dtype = torch.float32
device = "cuda"
btlm_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config = btlm_config_to_gpt2_config(btlm_config)
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model_ref = AutoModelForCausalLM.from_config(btlm_config, trust_remote_code=True).to(device)
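    # Check that our random initialization matches the HF reference statistically:
    # weight means near zero, matching standard deviations, and zero-initialized biases.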
assert model.transformer.embeddings.word_embeddings.weight.mean().abs() < 1e-4
assert (
model.transformer.embeddings.word_embeddings.weight.std()
- model_ref.transformer.wte.weight.std()
).abs() < 1e-4
assert model.lm_head.weight.mean().abs() < 1e-4
assert (model.lm_head.weight.std() - model_ref.lm_head.weight.std()).abs() < 1e-4
for l in range(config.n_layer):
assert model.transformer.layers[l].mixer.Wqkv.weight.mean().abs() < 1e-4
assert (
model.transformer.layers[l].mixer.Wqkv.weight.std()
- model_ref.transformer.h[l].attn.c_attn.weight.std()
).abs() < 1e-4
assert model.transformer.layers[l].mixer.Wqkv.bias.abs().max() == 0.0
assert model.transformer.layers[l].mixer.out_proj.weight.mean().abs() < 1e-4
assert (
model.transformer.layers[l].mixer.out_proj.weight.std()
- model_ref.transformer.h[l].attn.c_proj.weight.std()
).abs() < 1e-4
assert model.transformer.layers[l].mixer.out_proj.bias.abs().max() == 0.0
assert model.transformer.layers[l].mlp.fc1.weight.mean().abs() < 1e-4
assert (
model.transformer.layers[l].mlp.fc1.weight.std()
- model_ref.transformer.h[l].mlp.c_fc.weight.std()
).abs() < 1e-4
assert model.transformer.layers[l].mlp.fc1.bias.abs().max() == 0.0
assert model.transformer.layers[l].mlp.fc2.weight.mean().abs() < 1e-4
assert (
model.transformer.layers[l].mlp.fc2.weight.std()
- model_ref.transformer.h[l].mlp.c_proj.weight.std()
).abs() < 1e-4
assert model.transformer.layers[l].mlp.fc2.bias.abs().max() == 0.0
# Copyright (c) 2023, Tri Dao.
import os
import time
from pathlib import Path
current_dir = Path(__file__).parent.absolute()
import pytest
import torch
from einops import rearrange
from flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon
from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b", "tiiuae/falcon-40b"])
def test_falcon_state_dict(model_name):
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
pretrained_state_dict = remap_state_dict_hf_falcon(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_optimized(model_name):
"""Check that our implementation (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map={"": device}, trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
)
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_forward"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_forward(model_name, world_size):
from apex.transformer import parallel_state
dtype = torch.float16
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = False
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
config.fused_dropout_add_ln = False
config.residual_in_fp32 = True
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
pretrained_state_dict = remap_state_dict_hf_falcon(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
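    # The tensor-parallel model returns rank-local shards; gather the sequence-parallel
    # hidden states and the vocab-parallel logits from all ranks and reshape them so they
    # can be compared against the unsharded HF reference.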
with torch.no_grad():
out = model.transformer(input_ids)
out, _ = all_gather_raw(out, process_group=process_group)
out = rearrange(out, "(b s) d -> b s d", b=batch_size)
logits = model(input_ids).logits
logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
logits, _ = all_gather_raw(logits, process_group)
logits = rearrange(logits, "(n b) ... d -> b ... (n d)", b=batch_size)
del model
parallel_state.destroy_model_parallel()
if rank == 0:
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
)
model_hf.eval()
out_hf = model_hf.transformer(input_ids).last_hidden_state.to(device=device)
logits_hf = model_hf(input_ids).logits.to(device=device)
del model_hf
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-7b"])
def test_falcon_generation(model_name):
"""Check that our implementation (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
tokenizer = AutoTokenizer.from_pretrained(model_name)
eos_token_id = tokenizer.eos_token_id
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}, trust_remote_code=True
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map={"": device}, trust_remote_code=True
)
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
del model
hf_error = (logits_hf - logits_ref).abs().max().item()
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert torch.equal(logits_cg, logits)
# torchrun --no_python --nproc_per_node=4 pytest -q -s tests/models/test_falcon.py -k "falcon_parallel_generation"
# We want to run this on a machine with 4 x A100 80GB or 8 x A100 40GB so we have enough
# memory to run the model in fp32.
@pytest.mark.parametrize("world_size", [4])
@pytest.mark.parametrize("model_name", ["tiiuae/falcon-40b"])
def test_falcon_parallel_generation(model_name, world_size):
"""Check that our implementation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
from apex.transformer import parallel_state
dtype = torch.float16
config = falcon_config_to_gpt2_config(
AutoConfig.from_pretrained(model_name, trust_remote_code=True)
)
config.use_flash_attn = False
config.fused_bias_fc = True
config.fused_mlp = False # We don't have fused MLP for "gelu" activation
config.fused_dropout_add_ln = False
config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8 * world_size
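    # Pad the vocabulary so the embedding and LM-head weights split evenly across the
    # tensor-parallel ranks.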
config.sequence_parallel = False # Need to set this to False for generation
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
    # Need this, otherwise when we capture the CUDA graph, the process for GPU 1 would also
    # run on GPU 0 and things would hang
torch.cuda.set_device(device)
pretrained_state_dict = remap_state_dict_hf_falcon(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
model.eval()
print("Without CUDA graph")
out = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
cg=True,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
del model
parallel_state.destroy_model_parallel()
if rank == 0:
model_hf = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map="auto", trust_remote_code=True
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
with torch.inference_mode():
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = AutoModelForCausalLM.from_pretrained(
model_name, device_map="auto", trust_remote_code=True
)
model_ref.eval()
with torch.inference_mode():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
hf_error = (logits_hf - logits_ref).abs().max().item()
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert torch.equal(logits_cg, logits)
import re
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import (
GPTLMHeadModel,
remap_state_dict_hf_gpt2,
shard_state_dict_tp,
combine_state_dicts_tp,
)
from flash_attn.utils.generation import InferenceParams
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
config = GPT2Config.from_pretrained(model_name)
pretrained_state_dict = remap_state_dict_hf_gpt2(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config)
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):
"""Check that our implementation of GPT2 (without any optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
config = GPT2Config.from_pretrained(model_name)
model = GPTLMHeadModel.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.transformer(input_ids)
out_hf = model_hf.transformer(input_ids).last_hidden_state
out_ref = model_ref.transformer(input_ids).last_hidden_state
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
logits = model(input_ids).logits
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize("model_name", ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name):
"""Check that our implementation of GPT2 (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
config = GPT2Config.from_pretrained(model_name)
vocab_size_og = config.vocab_size
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
config.pad_vocab_size_multiple = 8
model = GPTLMHeadModel.from_pretrained(model_name, config)
model = model.cuda().to(dtype=dtype)
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)
model.eval()
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
batch_size = 4
max_seqlen = 512
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device="cuda")
input_ids = torch.randint(
0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long, device="cuda"
)
out = model.transformer(input_ids)
out_hf = model_hf.transformer(input_ids).last_hidden_state
out_ref = model_ref.transformer(input_ids).last_hidden_state
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
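    # The vocab was padded to a multiple of 8 above; drop the padded logit columns before
    # comparing against the HF models.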
logits = model(input_ids).logits[..., :vocab_size_og]
logits_hf = model_hf(input_ids).logits
logits_ref = model_ref(input_ids).logits
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize('optimized', [True])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize('rotary', [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation(model_name, rotary, optimized):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
config.n_positions = 0
config.rotary_emb_fraction = 0.5
config.rotary_emb_base = 24000
config.residual_in_fp32 = True
if optimized:
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
    # If not rotary, we load the weights from HF but ignore the position embeddings.
    # The resulting model would be nonsense, but that doesn't matter for this test.
model = GPTLMHeadModel.from_pretrained(
model_name, config, strict=not rotary, device=device, dtype=dtype
)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name, torch_dtype=dtype).to(
device=device
)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
device=device
)
max_length = 25
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
# Slow generation for reference
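    # (greedy decoding done token by token without a KV cache: re-run the whole prefix at
    # every step and take the argmax of the last position's logits)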
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
scores.append(model(cur_input_ids).logits[:, -1])
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
print(out.sequences)
print(tokenizer.batch_decode(out.sequences.tolist()))
if getattr(config, "use_flash_attn", False):
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
print(out_cg.sequences)
assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
if not rotary:
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
out_ref = model_ref.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(tokenizer.batch_decode(out_ref.sequences.tolist()))
assert torch.all(out.sequences == sequences)
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (
torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
out = model.generate(
input_ids=input_ids,
max_length=max_length,
teacher_outputs=teacher_outputs,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
**kwargs,
)
return torch.stack(out.scores, dim=1)
@pytest.mark.parametrize("seqlen,maxlen", [(10, 20), (30, 150), (3000, 3400), (14000, 15000)])
# @pytest.mark.parametrize('seqlen,maxlen', [(10, 20)])
@pytest.mark.parametrize("rotary", [None, "interleaved", "contiguous"])
# @pytest.mark.parametrize('rotary', [None])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_generation_cg(model_name, rotary, seqlen, maxlen):
"""Check that decoding with CUDA graph is the same as decoding without CUDA graph."""
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
config.n_positions = 16 * 1024
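    # Enlarge the learned position embedding table so the longest (14000, 15000) case fits.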
assert seqlen <= maxlen <= config.n_positions
if rotary is not None:
config.n_positions = 0
config.rotary_emb_dim = 32
config.rotary_emb_interleaved = rotary == "interleaved"
config.residual_in_fp32 = True
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
model = GPTLMHeadModel(config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
batch_size = 1
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)
# Try increasing batch size and seqlen, then decrease them to see if it's still correct
batch_size = 3
maxlen += 30
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)
batch_size = 2
maxlen -= 35
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
teacher_outputs = torch.randint(
0, config.vocab_size, (batch_size, maxlen), dtype=torch.long, device=device
)
logits = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs)
logits_cg = get_logits(model, input_ids, maxlen, teacher_outputs=teacher_outputs, cg=True)
assert torch.equal(logits, logits_cg)
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize("optimized", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_gpt2_multiple_token_generation(model_name, optimized):
"""Generation when we pass in multiple tokens at a time, not just one."""
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
config.residual_in_fp32 = True
if optimized:
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
input_ids = torch.randint(0, config.vocab_size, (1, 20), dtype=torch.long, device=device)
# Reference logits
logits_ref = model(input_ids).logits
# Run 10 tokens, then pass in another 4, then another 6, to see if we get the same logits
inference_params = InferenceParams(max_seqlen=20, max_batch_size=1)
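    # inference_params holds the KV cache; seqlen_offset tells the model how many tokens
    # are already cached, and position_ids must continue from that offset.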
logits_10 = model(input_ids[:, :10], inference_params=inference_params).logits
inference_params.seqlen_offset += 10
position_ids = torch.arange(10, 14, dtype=torch.long, device=device)
logits_1014 = model(
input_ids[:, 10:14], position_ids=position_ids, inference_params=inference_params
).logits
inference_params.seqlen_offset += 4
position_ids = torch.arange(14, 20, dtype=torch.long, device=device)
logits_1420 = model(
input_ids[:, 14:20], position_ids=position_ids, inference_params=inference_params
).logits
logits = torch.cat([logits_10, logits_1014, logits_1420], dim=1)
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
assert torch.allclose(logits, logits_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("cg", [False, True])
# @pytest.mark.parametrize("cg", [True])
@pytest.mark.parametrize("optimized", [False, True])
# @pytest.mark.parametrize("optimized", [True])
# @pytest.mark.parametrize("model_name", ["gpt2-medium"])
@pytest.mark.parametrize("model_name", ["gpt2-xl"])
def test_gpt2_speculative_decoding(model_name, optimized, cg):
if cg and not optimized:
        pytest.skip("CUDA graph decoding requires use_flash_attn")
dtype = torch.float16
device = "cuda"
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
config.residual_in_fp32 = True
if optimized:
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config_draft = GPT2Config.from_pretrained("gpt2")
config_draft.residual_in_fp32 = True
if optimized:
config_draft.use_flash_attn = True
config_draft.fused_bias_fc = True
config_draft.fused_mlp = True
config_draft.fused_dropout_add_ln = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
model_draft = GPTLMHeadModel.from_pretrained("gpt2", config_draft, device=device, dtype=dtype)
model_draft.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and he", return_tensors="pt").input_ids.to(
device=device
)
max_length = 100
from flash_attn.utils.generation import decode_speculative
torch.manual_seed(42)
print(f"Speculative decoding, {optimized = }")
out = decode_speculative(
input_ids,
model,
model_draft,
max_length=max_length,
top_k=5,
cg=cg,
speculative_lookahead=4,
enable_timing=True,
# debug=True,
)
print(tokenizer.batch_decode(out.sequences))
print(f"Without speculative decoding, {cg = }")
out_og = model.generate(
input_ids,
max_length=max_length,
top_k=5,
cg=cg,
enable_timing=True,
return_dict_in_generate=True,
)
print(tokenizer.batch_decode(out_og.sequences))
@pytest.mark.parametrize(
"n_heads_q_kv",
[
(8, 8), # Regular attention
(8, 4), # GQA
        (8, 2), # MQA-like (true MQA uses a single KV head)
],
)
def test_gpt2_shard_unshard(n_heads_q_kv):
world_size = 2
config = GPT2Config.from_pretrained("gpt2")
config.vocab_size = 1024
config.n_head, config.n_head_kv = n_heads_q_kv
model = GPTLMHeadModel(config, device="cuda", dtype=torch.float16)
state_dict = model.state_dict()
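    # Round-trip check: shard the full state dict across TP ranks, then recombine the shards
    # and verify we recover the original tensors exactly.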
shards = [
# NOTE: Shallow copy as `state_dict` is modified in-place
shard_state_dict_tp(dict(state_dict), config, world_size, rank)
for rank in range(world_size)
]
state_dict2 = combine_state_dicts_tp(shards, config)
assert state_dict2.keys() == state_dict.keys()
for k in state_dict.keys():
ref = state_dict[k]
        new = state_dict2[k]
assert torch.allclose(ref, new, atol=0.0, rtol=0.0)
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_generation_parallel.py -k "parallel"
import os
import re
import pytest
import torch
from einops import rearrange
from flash_attn.models.gpt import GPTLMHeadModel, remap_state_dict_hf_gpt2
from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import GPT2Config, GPT2Tokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF
# @pytest.mark.parametrize('world_size', [1, 2, 4, 8])
@pytest.mark.parametrize("world_size", [2])
@pytest.mark.parametrize("rotary", [False, True])
# @pytest.mark.parametrize("rotary", [False])
@pytest.mark.parametrize("model_name", ["gpt2"])
def test_tensor_parallel(model_name, rotary, world_size):
"""Check that our implementation of GPT2 generation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
dtype = torch.float16
rtol, atol = 3e-3, 3e-1
config = GPT2Config.from_pretrained(model_name)
if rotary:
config.n_positions = 0
config.rotary_emb_dim = 64
config.residual_in_fp32 = True
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.pad_vocab_size_multiple = 8 * world_size
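    # Pad the vocab to a multiple of 8 * world_size so the embedding / LM head shard evenly across ranks.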
config.sequence_parallel = False # Need to set this to False for generation
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch.cuda.set_device(device)
from apex.transformer import parallel_state
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
    # If rotary, we still load the weights from HF but ignore its position embeddings
    # (strict=False), so the model is nonsense; that doesn't matter for this test.
model = GPTLMHeadModel.from_pretrained(
model_name,
config,
strict=not rotary,
device=device,
dtype=dtype,
process_group=process_group,
world_size=world_size,
rank=rank,
)
model.eval()
if not rotary:
model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device)
model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).to(device=device, dtype=dtype)
model_ref.eval()
model_hf.eval()
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and ", return_tensors="pt").input_ids.to(
device=device
)
max_length = 30
# input_ids = torch.randint(0, 100, (1, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
# Slow generation for reference
sequences = []
scores = []
cur_input_ids = input_ids
with torch.inference_mode():
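        # Each rank produces logits only for its vocab shard; gather the shards, reassemble the
        # full vocab dimension, and drop the padding beyond the true vocab size.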
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
..., : config.vocab_size
]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
for _ in range(input_ids.shape[1] + 1, max_length):
cur_input_ids = torch.cat([cur_input_ids, rearrange(sequences[-1], "b -> b 1")], dim=-1)
logits, _ = all_gather_raw(model(cur_input_ids).logits[:, -1], process_group)
logits = rearrange(logits, "(n b) d -> b (n d)", b=input_ids.shape[0])[
..., : config.vocab_size
]
scores.append(logits)
sequences.append(scores[-1].argmax(dim=-1))
sequences = torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1)
scores = tuple(scores)
print(sequences)
out = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
print(out.sequences)
if getattr(config, "use_flash_attn", False):
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
tensor_parallel=world_size,
vocab_size=config.vocab_size,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
)
print(out_cg.sequences)
parallel_state.destroy_model_parallel()
if not rotary:
out_hf = model_hf.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
out_ref = model_ref.generate(
input_ids=input_ids,
max_length=max_length,
return_dict_in_generate=True,
output_scores=True,
)
print(
f"Scores max diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"Scores mean diff: {(torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
print(
f"HF fp16 max diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()}"
)
print(
f"HF fp16 mean diff: {(torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().mean().item()}"
)
assert torch.all(out.sequences == sequences)
assert torch.allclose(
torch.stack(out.scores, dim=1), torch.stack(scores, dim=1), rtol=rtol, atol=atol
)
assert torch.equal(torch.stack(out.scores, dim=1), torch.stack(out_cg.scores, dim=1))
if not rotary:
assert torch.all(out.sequences == out_ref.sequences)
assert torch.all(out.sequences == out_hf.sequences)
assert (
torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item() < 3 * (
torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
).abs().max().item()
# Copyright (c) 2023, Tri Dao.
import time
import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt_neox import gpt_neox_config_to_gpt2_config, remap_state_dict_hf_gpt_neox
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTNeoXConfig
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-neox-20b"])
def test_gpt_neox_state_dict(model_name):
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gpt_neox(
state_dict_from_pretrained(model_name), config
)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize(
"model_name",
[
"EleutherAI/pythia-1b",
"EleutherAI/pythia-2.8b",
"EleutherAI/gpt-neox-20b",
"togethercomputer/RedPajama-INCITE-7B-Base",
],
)
def test_gpt_neox_optimized(model_name):
"""Check that our implementation of GPT-NeoX (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(model_name))
config.use_flash_attn = True
config.fused_bias_fc = True
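    # Only enable the fused MLP for the (tanh-)approximate GeLU activations it implements.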
config.fused_mlp = config.activation_function in [
"gelu_fast",
"gelu_new",
"gelu_approx",
"gelu_pytorch_tanh",
]
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
# Need at least 2 GPUs, otherwise we'll OOM for the 20B model
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = GPTNeoXForCausalLM.from_pretrained(model_name, device_map="auto")
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.gpt_neox(input_ids).last_hidden_state.to(device=device)
logits_ref = model_ref(input_ids).logits.to(device=device)
del model_ref
model_hf = GPTNeoXForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
with torch.no_grad():
out_hf = model_hf.gpt_neox(input_ids).last_hidden_state
logits_hf = model_hf(input_ids).logits
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
assert (out - out_ref).abs().mean().item() < 2 * (out_hf - out_ref).abs().mean().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 2 * (
logits_hf - logits_ref
).abs().max().item()
assert (logits - logits_ref).abs().mean().item() < 2 * (
logits_hf - logits_ref
).abs().mean().item()
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/models/test_gpt_parallel.py
import math
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.transformer import parallel_state
from einops import rearrange
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from flash_attn.models.gpt import GPTLMHeadModel, shard_state_dict_tp
from flash_attn.utils.distributed import allreduce_sequence_parallel_grad
from transformers import GPT2Config
is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
@pytest.mark.parametrize("dtype", [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
@pytest.mark.parametrize("world_size", [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize("sequence_parallel", [True, False])
# @pytest.mark.parametrize('sequence_parallel', [False])
@pytest.mark.parametrize("has_pos_emb", [True, False])
# @pytest.mark.parametrize('has_pos_emb', [True])
@pytest.mark.parametrize("dim", [1024])
def test_gpt_parallel(dim, has_pos_emb, sequence_parallel, world_size, dtype):
head_dim = 64
assert dim % head_dim == 0
num_heads = dim // head_dim
assert num_heads % world_size == 0
vocab_size = 50264
assert vocab_size % world_size == 0
num_layers = 2
rtol, atol = (3e-3, 1e-1) if dtype == torch.bfloat16 else (3e-3, 1e-2)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl", init_method="env://")
device = f"cuda:{torch.distributed.get_rank()}"
assert world_size <= torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
rank = parallel_state.get_tensor_model_parallel_rank()
process_group = parallel_state.get_tensor_model_parallel_group()
# set seed
torch.random.manual_seed(0)
batch_size = 8
seqlen = 1024
assert (batch_size * seqlen) % world_size == 0
input_ids = torch.randint(0, vocab_size, (batch_size, seqlen + 1), device=device)
# We need to generate g here so that all processes get the same gradient,
# as rank 0 will have an extra bias that changes the RNG.
g = torch.randn(batch_size * seqlen, device=device)
config = GPT2Config(
n_embd=dim,
n_head=num_heads,
n_layer=num_layers,
n_positions=seqlen if has_pos_emb else 0,
vocab_size=50257,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
scale_attn_by_inverse_layer_idx=True,
use_flash_attn=True,
fused_mlp=True,
fused_bias_fc=True,
fused_dropout_add_ln=True,
residual_in_fp32=True,
rotary_emb_fraction=0.0 if has_pos_emb else 0.5,
pad_vocab_size_multiple=8 * world_size,
sequence_parallel=sequence_parallel,
)
config.vocab_size = math.ceil(config.vocab_size / (8 * world_size)) * (8 * world_size)
model_pt = GPTLMHeadModel(config, device=device)
def init_layer_norm(module):
if isinstance(module, nn.LayerNorm):
nn.init.normal_(module.weight)
nn.init.normal_(module.bias)
model_pt.apply(init_layer_norm)
model = GPTLMHeadModel(config, process_group=process_group, device=device)
total_nparams = sum(p.numel() for p in model_pt.parameters())
sharded_nparams = sum(p.numel() for p in model.parameters())
sharded_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor(
sharded_nparams_all, torch.tensor([sharded_nparams], device=device), group=process_group
)
shared_nparams = sum(
p.numel() for p in model.parameters() if getattr(p, "_shared_params", False)
)
shared_nparams_all = torch.empty(world_size, dtype=torch.long, device=device)
torch.distributed.all_gather_into_tensor(
shared_nparams_all, torch.tensor([shared_nparams], device=device), group=process_group
)
assert torch.all(shared_nparams_all == shared_nparams)
assert total_nparams == (
(sharded_nparams_all - shared_nparams_all).sum().item() + shared_nparams
)
# vocab_size has been rounded up here
partition_vocab_size = config.vocab_size // world_size
partition_dim = dim // world_size
partition_hidden_dim = 4 * dim // world_size
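    # Load each rank's shard of the reference model's weights so the TP model computes the
    # same function as the unsharded reference.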
with torch.no_grad():
model.load_state_dict(shard_state_dict_tp(model_pt.state_dict(), config, world_size, rank))
model.tie_weights()
with torch.autocast(device_type="cuda", dtype=dtype):
out = model(input_ids[:, :-1]).logits
if not sequence_parallel:
out = rearrange(out, "b s d -> (b s) d")
out_pt = rearrange(model_pt(input_ids[:, :-1]).logits, "b s d -> (b s) d")
partition_batch_dim = batch_size * seqlen // world_size
assert torch.allclose(
out,
out_pt[:, rank * partition_vocab_size : (rank + 1) * partition_vocab_size],
rtol=rtol,
atol=atol,
)
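    # The TP-aware CrossEntropyLoss reduces over the vocab dimension sharded across process_group,
    # so per-token losses should match the unsharded reference (loss_pt).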
loss_fn = CrossEntropyLoss(inplace_backward=True, reduction="none", process_group=process_group)
loss_fn_pt = CrossEntropyLoss(inplace_backward=True, reduction="none")
loss = loss_fn(out, input_ids[:, 1:].flatten())
loss_pt = loss_fn_pt(out_pt, input_ids[:, 1:].flatten())
assert torch.allclose(loss, loss_pt, rtol=rtol, atol=atol)
loss_pt.backward(g)
loss.backward(g)
allreduce_sequence_parallel_grad(model, process_group)
parallel_state.destroy_model_parallel()
grad_dict = shard_state_dict_tp(
{k: v.grad for k, v in model_pt.named_parameters()}, config, world_size, rank
)
assert torch.allclose(
model.transformer.embeddings.word_embeddings.weight.grad,
grad_dict["transformer.embeddings.word_embeddings.weight"],
rtol=rtol,
atol=atol * 5,
)
if has_pos_emb:
assert torch.allclose(
model.transformer.embeddings.position_embeddings.weight.grad,
grad_dict["transformer.embeddings.position_embeddings.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.ln_f.weight.grad,
grad_dict["transformer.ln_f.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.ln_f.bias.grad, grad_dict["transformer.ln_f.bias"], rtol=rtol, atol=atol
)
for i in range(num_layers):
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.weight.grad,
grad_dict[f"transformer.layers.{i}.mixer.Wqkv.weight"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mixer.Wqkv.bias.grad,
grad_dict[f"transformer.layers.{i}.mixer.Wqkv.bias"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.weight.grad,
grad_dict[f"transformer.layers.{i}.mixer.out_proj.weight"],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(
model.transformer.layers[i].mixer.out_proj.bias.grad,
grad_dict[f"transformer.layers.{i}.mixer.out_proj.bias"],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.weight.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc1.weight"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc1.bias.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc1.bias"],
rtol=rtol,
atol=atol * 10,
)
assert torch.allclose(
model.transformer.layers[i].mlp.fc2.weight.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc2.weight"],
rtol=rtol,
atol=atol * 10,
)
if rank == 0:
assert torch.allclose(
model.transformer.layers[i].mlp.fc2.bias.grad,
grad_dict[f"transformer.layers.{i}.mlp.fc2.bias"],
rtol=rtol,
atol=atol * 5,
)
assert torch.allclose(
model.transformer.layers[i].norm1.weight.grad,
grad_dict[f"transformer.layers.{i}.norm1.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm1.bias.grad,
grad_dict[f"transformer.layers.{i}.norm1.bias"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.weight.grad,
grad_dict[f"transformer.layers.{i}.norm2.weight"],
rtol=rtol,
atol=atol,
)
assert torch.allclose(
model.transformer.layers[i].norm2.bias.grad,
grad_dict[f"transformer.layers.{i}.norm2.bias"],
rtol=rtol,
atol=atol,
)
# Copyright (c) 2023, Tri Dao.
import time
import pytest
import torch
from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gptj import gptj_config_to_gpt2_config, remap_state_dict_hf_gptj
from flash_attn.utils.generation import update_graph_cache
from flash_attn.utils.pretrained import state_dict_from_pretrained
from transformers import AutoTokenizer, GPTJConfig
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_state_dict(model_name):
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
pretrained_state_dict = remap_state_dict_hf_gptj(state_dict_from_pretrained(model_name), config)
model = GPTLMHeadModel(config, device="meta") # Without device='meta' init is very slow
state_dict = model.state_dict()
assert state_dict.keys() == pretrained_state_dict.keys()
for k in state_dict.keys():
assert state_dict[k].shape == pretrained_state_dict[k].shape
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B", "togethercomputer/GPT-JT-6B-v1"])
def test_gptj_optimized(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
config.residual_in_fp32 = True
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
torch.manual_seed(0)
batch_size = 2
max_seqlen = 256
input_ids = torch.randint(
0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device
)
with torch.no_grad():
out = model.transformer(input_ids)
logits = model(input_ids).logits
del model
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
with torch.no_grad():
out_ref = model_ref.transformer(input_ids).last_hidden_state
logits_ref = model_ref(input_ids).logits
del model_ref
model_hf = GPTJForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
    with torch.no_grad():
        out_hf = model_hf.transformer(input_ids).last_hidden_state
        logits_hf = model_hf(input_ids).logits
del model_hf
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}")
assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
print(f"Logits max diff: {(logits - logits_ref).abs().max().item()}")
print(f"Logits mean diff: {(logits - logits_ref).abs().mean().item()}")
print(f"HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}")
print(f"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}")
assert (logits - logits_ref).abs().max().item() < 3 * (
logits_hf - logits_ref
).abs().max().item()
@pytest.mark.parametrize("model_name", ["EleutherAI/gpt-j-6B"])
def test_gptj_generation(model_name):
"""Check that our implementation of GPT-J (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
dtype = torch.float16
device = "cuda"
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
config.use_flash_attn = True # FlashAttention-2 supports headdim 256
config.fused_bias_fc = True
config.fused_mlp = True
config.fused_dropout_add_ln = True
# Only prenorm supports residual_in_fp32
config.residual_in_fp32 = True
tokenizer = AutoTokenizer.from_pretrained(model_name)
eos_token_id = tokenizer.eos_token_id
torch.manual_seed(0)
batch_size = 1
seqlen = 100
max_length = 150
input_ids = torch.randint(
0, config.vocab_size, (batch_size, seqlen), dtype=torch.long, device=device
)
model_hf = GPTJForCausalLM.from_pretrained(
model_name, torch_dtype=dtype, device_map={"": device}
)
model_hf.eval()
print("HF fp16")
torch.cuda.synchronize()
start = time.time()
out_hf = model_hf.generate(
input_ids=input_ids, max_length=max_length, return_dict_in_generate=True, output_scores=True
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
del model_hf
model_ref = GPTJForCausalLM.from_pretrained(model_name, device_map={"": device})
model_ref.eval()
with torch.no_grad():
logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1) : -1]
del model_ref
model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)
model.eval()
print("Without CUDA graph")
torch.cuda.synchronize()
start = time.time()
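    # teacher_outputs forces generation to follow the HF-generated sequence, so the scores from
    # the different decoding paths can be compared position by position against logits_ref.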
out = model.generate(
input_ids=input_ids,
max_length=max_length,
eos_token_id=eos_token_id,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
# Capture graph outside the timing loop
batch_size, seqlen_og = input_ids.shape
model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
print("With CUDA graph")
torch.cuda.synchronize()
start = time.time()
out_cg = model.generate(
input_ids=input_ids,
max_length=max_length,
cg=True,
return_dict_in_generate=True,
output_scores=True,
enable_timing=True,
teacher_outputs=out_hf.sequences,
)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
with torch.no_grad():
logits_parallel = model(out_hf.sequences).logits[:, (seqlen - 1) : -1]
logits_hf = torch.stack(out_hf.scores, dim=1)
logits = torch.stack(out.scores, dim=1)
logits_cg = torch.stack(out_cg.scores, dim=1)
del model
hf_error = (logits_hf - logits_ref).abs().max().item()
assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
print(f"HF fp16 logits max diff: {hf_error}")
print(f"Logits max diff: {(logits - logits_ref).abs().max().item() }")
assert (logits - logits_ref).abs().max().item() < 2 * hf_error
print(f"Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }")
assert torch.equal(logits_cg, logits)