diff --git a/.gitignore b/.gitignore index a143965a0238d469c21f30ce911ee6d18112abc7..e8bb8c8570b5e231720120f2f8600792a13b86e8 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ __pycache__/ *.egg-info/ workspace/ +.cache +*build*/ diff --git a/3rdparty/INIReader.h b/3rdparty/INIReader.h new file mode 100644 index 0000000000000000000000000000000000000000..7d40f0638f2ba88342035b8e33c45a9029320d84 --- /dev/null +++ b/3rdparty/INIReader.h @@ -0,0 +1,501 @@ +// Read an INI file into easy-to-access name/value pairs. + +// inih and INIReader are released under the New BSD license. +// Go to the project home page for more info: +// +// https://github.com/benhoyt/inih (Initial repo) +// https://github.com/jtilly/inih (The reference of this header file) +/* inih -- simple .INI file parser +inih is released under the New BSD license (see LICENSE.txt). Go to the project +home page for more info: +https://github.com/benhoyt/inih +https://github.com/jtilly/inih +*/ + +#ifndef __INI_H__ +#define __INI_H__ + +/* Make this header file easier to include in C++ code */ +#ifdef __cplusplus +extern "C" { +#endif + +#include + +/* Typedef for prototype of handler function. */ +typedef int (*ini_handler)(void* user, const char* section, + const char* name, const char* value); + +/* Typedef for prototype of fgets-style reader function. */ +typedef char* (*ini_reader)(char* str, int num, void* stream); + +/* Parse given INI-style file. May have [section]s, name=value pairs + (whitespace stripped), and comments starting with ';' (semicolon). Section + is "" if name=value pair parsed before any section heading. name:value + pairs are also supported as a concession to Python's configparser. + For each name=value pair parsed, call handler function with given user + pointer as well as section, name, and value (data only valid for duration + of handler call). Handler should return nonzero on success, zero on error. + Returns 0 on success, line number of first error on parse error (doesn't + stop on first error), -1 on file open error, or -2 on memory allocation + error (only when INI_USE_STACK is zero). +*/ +int ini_parse(const char* filename, ini_handler handler, void* user); + +/* Same as ini_parse(), but takes a FILE* instead of filename. This doesn't + close the file when it's finished -- the caller must do that. */ +int ini_parse_file(FILE* file, ini_handler handler, void* user); + +/* Same as ini_parse(), but takes an ini_reader function pointer instead of + filename. Used for implementing custom or string-based I/O. */ +int ini_parse_stream(ini_reader reader, void* stream, ini_handler handler, + void* user); + +/* Nonzero to allow multi-line value parsing, in the style of Python's + configparser. If allowed, ini_parse() will call the handler with the same + name for each subsequent line parsed. */ +#ifndef INI_ALLOW_MULTILINE +#define INI_ALLOW_MULTILINE 1 +#endif + +/* Nonzero to allow a UTF-8 BOM sequence (0xEF 0xBB 0xBF) at the start of + the file. See http://code.google.com/p/inih/issues/detail?id=21 */ +#ifndef INI_ALLOW_BOM +#define INI_ALLOW_BOM 1 +#endif + +/* Nonzero to allow inline comments (with valid inline comment characters + specified by INI_INLINE_COMMENT_PREFIXES). Set to 0 to turn off and match + Python 3.2+ configparser behaviour. */ +#ifndef INI_ALLOW_INLINE_COMMENTS +#define INI_ALLOW_INLINE_COMMENTS 1 +#endif +#ifndef INI_INLINE_COMMENT_PREFIXES +#define INI_INLINE_COMMENT_PREFIXES ";" +#endif + +/* Nonzero to use stack, zero to use heap (malloc/free). 
*/ +#ifndef INI_USE_STACK +#define INI_USE_STACK 1 +#endif + +/* Stop parsing on first error (default is to keep parsing). */ +#ifndef INI_STOP_ON_FIRST_ERROR +#define INI_STOP_ON_FIRST_ERROR 0 +#endif + +/* Maximum line length for any line in INI file. */ +#ifndef INI_MAX_LINE +#define INI_MAX_LINE 200 +#endif + +#ifdef __cplusplus +} +#endif + +/* inih -- simple .INI file parser +inih is released under the New BSD license (see LICENSE.txt). Go to the project +home page for more info: +https://github.com/benhoyt/inih +*/ + +#if defined(_MSC_VER) && !defined(_CRT_SECURE_NO_WARNINGS) +#define _CRT_SECURE_NO_WARNINGS +#endif + +#include +#include +#include + +#if !INI_USE_STACK +#include +#endif + +#define MAX_SECTION 50 +#define MAX_NAME 50 + +/* Strip whitespace chars off end of given string, in place. Return s. */ +inline static char* rstrip(char* s) +{ + char* p = s + strlen(s); + while (p > s && isspace((unsigned char)(*--p))) + *p = '\0'; + return s; +} + +/* Return pointer to first non-whitespace char in given string. */ +inline static char* lskip(const char* s) +{ + while (*s && isspace((unsigned char)(*s))) + s++; + return (char*)s; +} + +/* Return pointer to first char (of chars) or inline comment in given string, + or pointer to null at end of string if neither found. Inline comment must + be prefixed by a whitespace character to register as a comment. */ +inline static char* find_chars_or_comment(const char* s, const char* chars) +{ +#if INI_ALLOW_INLINE_COMMENTS + int was_space = 0; + while (*s && (!chars || !strchr(chars, *s)) && + !(was_space && strchr(INI_INLINE_COMMENT_PREFIXES, *s))) { + was_space = isspace((unsigned char)(*s)); + s++; + } +#else + while (*s && (!chars || !strchr(chars, *s))) { + s++; + } +#endif + return (char*)s; +} + +/* Version of strncpy that ensures dest (size bytes) is null-terminated. */ +inline static char* strncpy0(char* dest, const char* src, size_t size) +{ + strncpy(dest, src, size); + dest[size - 1] = '\0'; + return dest; +} + +/* See documentation in header file. */ +inline int ini_parse_stream(ini_reader reader, void* stream, ini_handler handler, + void* user) +{ + /* Uses a fair bit of stack (use heap instead if you need to) */ +#if INI_USE_STACK + char line[INI_MAX_LINE]; +#else + char* line; +#endif + char section[MAX_SECTION] = ""; + char prev_name[MAX_NAME] = ""; + + char* start; + char* end; + char* name; + char* value; + int lineno = 0; + int error = 0; + +#if !INI_USE_STACK + line = (char*)malloc(INI_MAX_LINE); + if (!line) { + return -2; + } +#endif + + /* Scan through stream line by line */ + while (reader(line, INI_MAX_LINE, stream) != NULL) { + lineno++; + + start = line; +#if INI_ALLOW_BOM + if (lineno == 1 && (unsigned char)start[0] == 0xEF && + (unsigned char)start[1] == 0xBB && + (unsigned char)start[2] == 0xBF) { + start += 3; + } +#endif + start = lskip(rstrip(start)); + + if (*start == ';' || *start == '#') { + /* Per Python configparser, allow both ; and # comments at the + start of a line */ + } +#if INI_ALLOW_MULTILINE + else if (*prev_name && *start && start > line) { + +#if INI_ALLOW_INLINE_COMMENTS + end = find_chars_or_comment(start, NULL); + if (*end) + *end = '\0'; + rstrip(start); +#endif + + /* Non-blank line with leading whitespace, treat as continuation + of previous name's value (as per Python configparser). 
*/ + if (!handler(user, section, prev_name, start) && !error) + error = lineno; + } +#endif + else if (*start == '[') { + /* A "[section]" line */ + end = find_chars_or_comment(start + 1, "]"); + if (*end == ']') { + *end = '\0'; + strncpy0(section, start + 1, sizeof(section)); + *prev_name = '\0'; + } + else if (!error) { + /* No ']' found on section line */ + error = lineno; + } + } + else if (*start) { + /* Not a comment, must be a name[=:]value pair */ + end = find_chars_or_comment(start, "=:"); + if (*end == '=' || *end == ':') { + *end = '\0'; + name = rstrip(start); + value = lskip(end + 1); +#if INI_ALLOW_INLINE_COMMENTS + end = find_chars_or_comment(value, NULL); + if (*end) + *end = '\0'; +#endif + rstrip(value); + + /* Valid name[=:]value pair found, call handler */ + strncpy0(prev_name, name, sizeof(prev_name)); + if (!handler(user, section, name, value) && !error) + error = lineno; + } + else if (!error) { + /* No '=' or ':' found on name[=:]value line */ + error = lineno; + } + } + +#if INI_STOP_ON_FIRST_ERROR + if (error) + break; +#endif + } + +#if !INI_USE_STACK + free(line); +#endif + + return error; +} + +/* See documentation in header file. */ +inline int ini_parse_file(FILE* file, ini_handler handler, void* user) +{ + return ini_parse_stream((ini_reader)fgets, file, handler, user); +} + +/* See documentation in header file. */ +inline int ini_parse(const char* filename, ini_handler handler, void* user) +{ + FILE* file; + int error; + + file = fopen(filename, "r"); + if (!file) + return -1; + error = ini_parse_file(file, handler, user); + fclose(file); + return error; +} + +#endif /* __INI_H__ */ + + +#ifndef __INIREADER_H__ +#define __INIREADER_H__ + +#include +#include +#include + +// Read an INI file into easy-to-access name/value pairs. (Note that I've gone +// for simplicity here rather than speed, but it should be pretty decent.) +class INIReader +{ +public: + // Empty Constructor + INIReader() {}; + + // Construct INIReader and parse given filename. See ini.h for more info + // about the parsing. + INIReader(std::string filename); + + // Construct INIReader and parse given file. See ini.h for more info + // about the parsing. + INIReader(FILE *file); + ~INIReader(); + // Return the result of ini_parse(), i.e., 0 on success, line number of + // first error on parse error, or -1 on file open error. + int ParseError() const; + + // Return the list of sections found in ini file + const std::set& Sections() const; + + // Get a string value from INI file, returning default_value if not found. + std::string Get(std::string section, std::string name, + std::string default_value) const; + std::string Get(std::string section, std::string name) const; + + // Get an integer (long) value from INI file, returning default_value if + // not found or not a valid integer (decimal "1234", "-1234", or hex "0x4d2"). + long GetInteger(std::string section, std::string name, long default_value) const; + long GetInteger(std::string section, std::string name) const; + + // Get a real (floating point double) value from INI file, returning + // default_value if not found or not a valid floating point value + // according to strtod(). + double GetReal(std::string section, std::string name, double default_value) const; + + // Get a single precision floating point number value from INI file, returning + // default_value if not found or not a valid floating point value + // according to strtof(). 
+ float GetFloat(std::string section, std::string name, float default_value) const; + float GetFloat(std::string section, std::string name) const; + + // Get a boolean value from INI file, returning default_value if not found or if + // not a valid true/false value. Valid true values are "true", "yes", "on", "1", + // and valid false values are "false", "no", "off", "0" (not case sensitive). + bool GetBoolean(std::string section, std::string name, bool default_value) const; + +protected: + int _error; + std::map _values; + std::set _sections; + static std::string MakeKey(std::string section, std::string name); + static int ValueHandler(void* user, const char* section, const char* name, + const char* value); +}; + +#endif // __INIREADER_H__ + + +#ifndef __INIREADER__ +#define __INIREADER__ + +#include +#include +#include + +inline INIReader::INIReader(std::string filename) +{ + _error = ini_parse(filename.c_str(), ValueHandler, this); +} + +inline INIReader::INIReader(FILE *file) +{ + _error = ini_parse_file(file, ValueHandler, this); +} + +inline int INIReader::ParseError() const +{ + return _error; +} + +inline INIReader::~INIReader() { } + +inline const std::set& INIReader::Sections() const +{ + return _sections; +} + +inline std::string INIReader::Get(std::string section, std::string name, std::string default_value) const +{ + std::string key = MakeKey(section, name); + return _values.count(key) ? _values.at(key) : default_value; +} + +inline std::string INIReader::Get(std::string section, std::string name) const +{ + std::string key = MakeKey(section, name); + if(_values.count(key)) return _values.at(key); + else + { + printf("[ERROR] Does not find the section %s with name %s. \n", section.c_str(), name.c_str()); + exit(-1); + } +} + +inline long INIReader::GetInteger(std::string section, std::string name, long default_value) const +{ + std::string valstr = Get(section, name, ""); + const char* value = valstr.c_str(); + char* end; + // This parses "1234" (decimal) and also "0x4D2" (hex) + long n = strtol(value, &end, 0); + return end > value ? n : default_value; +} + +inline long INIReader::GetInteger(std::string section, std::string name) const +{ + std::string valstr = Get(section, name, ""); + const char* value = valstr.c_str(); + char* end; + // This parses "1234" (decimal) and also "0x4D2" (hex) + long n = strtol(value, &end, 0); + if(end <= value) + { + printf("[ERROR] Does not find the section %s with name %s. \n", section.c_str(), name.c_str()); + exit(-1); + } + return n; +} + +inline double INIReader::GetReal(std::string section, std::string name, double default_value) const +{ + std::string valstr = Get(section, name, ""); + const char* value = valstr.c_str(); + char* end; + double n = strtod(value, &end); + return end > value ? n : default_value; +} + +inline float INIReader::GetFloat(std::string section, std::string name, float default_value) const +{ + std::string valstr = Get(section, name, ""); + const char* value = valstr.c_str(); + char* end; + float n = strtof(value, &end); + return end > value ? n : default_value; +} + +inline float INIReader::GetFloat(std::string section, std::string name) const +{ + std::string valstr = Get(section, name, ""); + const char* value = valstr.c_str(); + char* end; + float n = strtof(value, &end); + if(end <= value) + { + printf("[ERROR] Does not find the section %s with name %s. 
\n", section.c_str(), name.c_str()); + exit(-1); + } + return n; +} + +inline bool INIReader::GetBoolean(std::string section, std::string name, bool default_value) const +{ + std::string valstr = Get(section, name, ""); + // Convert to lower case to make string comparisons case-insensitive + std::transform(valstr.begin(), valstr.end(), valstr.begin(), ::tolower); + if (valstr == "true" || valstr == "yes" || valstr == "on" || valstr == "1") + return true; + else if (valstr == "false" || valstr == "no" || valstr == "off" || valstr == "0") + return false; + else + return default_value; +} + +inline std::string INIReader::MakeKey(std::string section, std::string name) +{ + std::string key = section + "=" + name; + // Convert to lower case to make section/name lookups case-insensitive + std::transform(key.begin(), key.end(), key.begin(), ::tolower); + return key; +} + +inline int INIReader::ValueHandler(void* user, const char* section, const char* name, + const char* value) +{ + INIReader* reader = (INIReader*)user; + std::string key = MakeKey(section, name); + if (reader->_values[key].size() > 0) + reader->_values[key] += "\n"; + reader->_values[key] += value; + reader->_sections.insert(section); + return 1; +} + +#endif // __INIREADER__ \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100755 index 0000000000000000000000000000000000000000..b81abb3b86a592f47bc072c56e68cb264fce6691 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,399 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +cmake_minimum_required(VERSION 3.11 FATAL_ERROR) # for PyTorch extensions, version should be greater than 3.13 +project(FasterTransformer LANGUAGES CXX CUDA) + +find_package(CUDA 10.2 REQUIRED) + +if(${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11") + add_definitions("-DENABLE_BF16") + message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.0, enable -DENABLE_BF16 flag") +endif() + +if((${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "11" AND ${CUDA_VERSION_MINOR} VERSION_GREATER_EQUAL "8") OR (${CUDA_VERSION_MAJOR} VERSION_GREATER_EQUAL "12")) + add_definitions("-DENABLE_FP8") + option(ENABLE_FP8 "ENABLE_FP8" OFF) + if(ENABLE_FP8) + message("CUDA_VERSION ${CUDA_VERSION_MAJOR}.${CUDA_VERSION_MINOR} is greater or equal than 11.8, enable -DENABLE_FP8 flag") + endif() +endif() + +set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) + +option(BUILD_PYT "Build in PyTorch TorchScript class mode" OFF) +if(NOT BUILD_MULTI_GPU) + option(BUILD_MULTI_GPU "Build project about multi-GPU" OFF) +endif() +if(NOT USE_TRITONSERVER_DATATYPE) + option(USE_TRITONSERVER_DATATYPE "Build triton backend for triton server" OFF) +endif() + +include(FetchContent) + +FetchContent_Declare( + repo-cutlass + GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git + GIT_TAG cc85b64cf676c45f98a17e3a47c0aafcf817f088 +) + +set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + +FetchContent_MakeAvailable(repo-cutlass) + +set(CUTLASS_HEADER_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass/include) +set(CUTLASS_EXTENSIONS_DIR ${PROJECT_SOURCE_DIR}/src/fastertransformer/cutlass_extensions/include) + +option(SPARSITY_SUPPORT "Build project with Ampere sparsity feature support" OFF) + +option(BUILD_FAST_MATH "Build in fast math mode" ON) + +if(BUILD_MULTI_GPU) + message(STATUS "Add DBUILD_MULTI_GPU, requires MPI and NCCL") + add_definitions("-DBUILD_MULTI_GPU") + set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) + find_package(MPI REQUIRED) + find_package(NCCL REQUIRED) + set(CMAKE_MODULE_PATH "") # prevent the bugs for pytorch building +endif() + +if(BUILD_PYT) + if(DEFINED ENV{NVIDIA_PYTORCH_VERSION}) + if($ENV{NVIDIA_PYTORCH_VERSION} VERSION_LESS "20.03") + message(FATAL_ERROR "NVIDIA PyTorch image is too old for TorchScript mode.") + endif() + if($ENV{NVIDIA_PYTORCH_VERSION} VERSION_EQUAL "20.03") + add_definitions(-DLEGACY_THS=1) + endif() + endif() +endif() + +if(USE_TRITONSERVER_DATATYPE) + message("-- USE_TRITONSERVER_DATATYPE") + add_definitions("-DUSE_TRITONSERVER_DATATYPE") +endif() + +set(CXX_STD "14" CACHE STRING "C++ standard") + +set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR}) + +set(TF_PATH "" CACHE STRING "TensorFlow path") +set(CUSPARSELT_PATH "" CACHE STRING "cuSPARSELt path") + +if((BUILD_TF OR BUILD_TF2) AND NOT TF_PATH) + message(FATAL_ERROR "TF_PATH must be set if BUILD_TF or BUILD_TF2 (=TensorFlow mode) is on.") +endif() + +list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64) + +# profiling +option(USE_NVTX "Whether or not to use nvtx" ON) +if(USE_NVTX) + message(STATUS "NVTX is enabled.") + add_definitions("-DUSE_NVTX") +endif() + +# setting compiler flags +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall -ldl") # -Xptxas -v + +set(SM_SETS 52 60 61 70 75 80 86 89 90) +set(USING_WMMA False) +set(FIND_SM False) + +foreach(SM_NUM IN LISTS SM_SETS) + string(FIND "${SM}" "${SM_NUM}" SM_POS) + if(SM_POS GREATER -1) + if(FIND_SM STREQUAL False) + set(ENV{TORCH_CUDA_ARCH_LIST} 
"") + endif() + set(FIND_SM True) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM_NUM},code=\\\"sm_${SM_NUM},compute_${SM_NUM}\\\"") + + if (SM_NUM STREQUAL 70 OR SM_NUM STREQUAL 75 OR SM_NUM STREQUAL 80 OR SM_NUM STREQUAL 86 OR SM_NUM STREQUAL 89 OR SM_NUM STREQUAL 90) + set(USING_WMMA True) + endif() + + if(BUILD_PYT) + string(SUBSTRING ${SM_NUM} 0 1 SM_MAJOR) + string(SUBSTRING ${SM_NUM} 1 1 SM_MINOR) + set(ENV{TORCH_CUDA_ARCH_LIST} "$ENV{TORCH_CUDA_ARCH_LIST}\;${SM_MAJOR}.${SM_MINOR}") + endif() + + list(APPEND CMAKE_CUDA_ARCHITECTURES ${SM_NUM}) + message("-- Assign GPU architecture (sm=${SM_NUM})") + endif() +endforeach() + +if(USING_WMMA STREQUAL True) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") + message("-- Use WMMA") +endif() + +if(NOT (FIND_SM STREQUAL True)) + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \ + -gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \ + -gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \ + -gencode=arch=compute_80,code=\\\"sm_80,compute_80\\\" \ + -gencode=arch=compute_86,code=\\\"sm_86,compute_86\\\" \ + ") + # -rdc=true") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA") + if(BUILD_PYT) + set(ENV{TORCH_CUDA_ARCH_LIST} "7.0;7.5;8.0;8.6") + endif() + set(CMAKE_CUDA_ARCHITECTURES 70 75 80 86) + message("-- Assign GPU architecture (sm=70,75,80,86)") +endif() + +if(BUILD_PYT) + set(TORCH_CUDA_ARCH_LIST $ENV{TORCH_CUDA_ARCH_LIST}) +endif() + +set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0") +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0") +# set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall --ptxas-options=-v --resource-usage") +set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall -DCUDA_PTX_FP8_F2FP_ENABLED") + +set(CMAKE_CXX_STANDARD "${CXX_STD}") +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++${CXX_STD} -DCUDA_PTX_FP8_F2FP_ENABLED") + +set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") +# set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 --ptxas-options=--verbose") +set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3 -DCUDA_PTX_FP8_F2FP_ENABLED") +if(BUILD_FAST_MATH) +set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} --use_fast_math") +message("CMAKE_CUDA_FLAGS_RELEASE: ${CMAKE_CUDA_FLAGS_RELEASE}") +endif() + +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +set(COMMON_HEADER_DIRS + ${PROJECT_SOURCE_DIR} + ${CUDA_PATH}/include + ${CUTLASS_HEADER_DIR} +) +message("-- COMMON_HEADER_DIRS: ${COMMON_HEADER_DIRS}") + +set(COMMON_LIB_DIRS + ${CUDA_PATH}/lib64 +) + +if (SPARSITY_SUPPORT) + list(APPEND COMMON_HEADER_DIRS ${CUSPARSELT_PATH}/include) + list(APPEND COMMON_LIB_DIRS ${CUSPARSELT_PATH}/lib64) + add_definitions(-DSPARSITY_ENABLED=1) +endif() + +if(BUILD_TF) + list(APPEND COMMON_HEADER_DIRS ${TF_PATH}/include) + list(APPEND COMMON_LIB_DIRS ${TF_PATH}) + add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +endif() + +if(BUILD_TF2) + list(APPEND COMMON_HEADER_DIRS 
${TF_PATH}/include) + list(APPEND COMMON_LIB_DIRS ${TF_PATH}) + add_definitions(-D_GLIBCXX_USE_CXX11_ABI=1) +endif() + +set(PYTHON_PATH "python" CACHE STRING "Python path") +if(BUILD_PYT) + execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import torch; print(torch.__version__,end='');" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE TORCH_VERSION) + if (TORCH_VERSION VERSION_LESS "1.5.0") + message(FATAL_ERROR "PyTorch >= 1.5.0 is needed for TorchScript mode.") + endif() + execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import os; import torch; +print(os.path.dirname(torch.__file__),end='');" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE TORCH_DIR) + if (NOT _PYTHON_SUCCESS MATCHES 0) + message(FATAL_ERROR "Torch config Error.") + endif() + list(APPEND CMAKE_PREFIX_PATH ${TORCH_DIR}) + find_package(Torch REQUIRED) + execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; from distutils import sysconfig; +print(sysconfig.get_python_inc());" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE PY_INCLUDE_DIR) + if (NOT _PYTHON_SUCCESS MATCHES 0) + message(FATAL_ERROR "Python config Error.") + endif() + list(APPEND COMMON_HEADER_DIRS ${PY_INCLUDE_DIR}) + execute_process(COMMAND ${PYTHON_PATH} "-c" "from __future__ import print_function; import torch; +print(torch._C._GLIBCXX_USE_CXX11_ABI,end='');" + RESULT_VARIABLE _PYTHON_SUCCESS + OUTPUT_VARIABLE USE_CXX11_ABI) + message("-- USE_CXX11_ABI=${USE_CXX11_ABI}") + if (USE_CXX11_ABI) + set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=1") + set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=1") + else() + set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -D_GLIBCXX_USE_CXX11_ABI=0") + set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0") + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -D_GLIBCXX_USE_CXX11_ABI=0") + endif() +endif() + +if (BUILD_MULTI_GPU) + list(APPEND COMMON_HEADER_DIRS ${MPI_INCLUDE_PATH}) + list(APPEND COMMON_LIB_DIRS /usr/local/mpi/lib) +endif() + +if(USE_TRITONSERVER_DATATYPE) + list(APPEND COMMON_HEADER_DIRS ${PROJECT_SOURCE_DIR}/../repo-core-src/include) +endif() + +include_directories( + ${COMMON_HEADER_DIRS} +) + +link_directories( + ${COMMON_LIB_DIRS} +) + +# add_subdirectory(3rdparty) +add_subdirectory(src) +add_subdirectory(examples) + +add_subdirectory(tests) + +# # Mesaure the compile time +option(MEASURE_BUILD_TIME "Measure the build time of each module" OFF) +if (MEASURE_BUILD_TIME) + set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_COMMAND} -E time") + set_property(GLOBAL PROPERTY RULE_LAUNCH_CUSTOM "${CMAKE_COMMAND} -E time") + set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK "${CMAKE_COMMAND} -E time") +endif() + +######################################## + +add_library(transformer-shared SHARED + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ + $ +) + +if (BUILD_MULTI_GPU) +target_link_libraries(transformer-shared PUBLIC + -lmpi + ${NCCL_LIBRARIES} +) +endif() + +if(USE_NVTX) +target_link_libraries(transformer-shared PUBLIC + -lnvToolsExt +) 
+endif() + +set_target_properties(transformer-shared PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(transformer-shared PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON) +set_target_properties(transformer-shared PROPERTIES LINKER_LANGUAGE CXX) +target_link_libraries(transformer-shared PUBLIC -lcudart -lcublas -lcublasLt -lcurand) + +include(GNUInstallDirs) +set(INSTALL_CONFIGDIR ${CMAKE_INSTALL_LIBDIR}/cmake/FasterTransformer) + +include(CMakePackageConfigHelpers) +configure_package_config_file( + ${CMAKE_CURRENT_LIST_DIR}/cmake/FasterTransformerConfig.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerConfig.cmake + INSTALL_DESTINATION ${INSTALL_CONFIGDIR} +) + +install( + FILES + ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerConfig.cmake + DESTINATION ${INSTALL_CONFIGDIR} +) + +install( + TARGETS + transformer-shared + EXPORT + transformer-shared-targets + LIBRARY DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/fastertransformer + ARCHIVE DESTINATION ${CMAKE_INSTALL_PREFIX}/backends/fastertransformer +) + +install( + EXPORT + transformer-shared-targets + FILE + FasterTransformerTargets.cmake + DESTINATION + ${INSTALL_CONFIGDIR} +) + +export( + EXPORT + transformer-shared-targets + FILE + ${CMAKE_CURRENT_BINARY_DIR}/FasterTransformerTargets.cmake + NAMESPACE + TritonCore:: +) + +export(PACKAGE FasterTransformer) diff --git a/cmake/FasterTransformerConfig.cmake.in b/cmake/FasterTransformerConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..290213c9699e66cefbbe6ebe2b81be2f0c13fedb --- /dev/null +++ b/cmake/FasterTransformerConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
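+# How a downstream CMake project would typically consume this package config
+# (a sketch, assuming the package was installed to a prefix visible on
+# CMAKE_PREFIX_PATH; the target name "my_app" is a placeholder):
+#   find_package(FasterTransformer REQUIRED)
+#   target_link_libraries(my_app PRIVATE ${FASTERTRANSFORMER_LIBRARIES})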
+ +include(CMakeFindDependencyMacro) + +get_filename_component( + FASTERTRANSFORMER_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${FASTERTRANSFORMER_CMAKE_DIR}) + +if(NOT TARGET transformer-shared) + include("${FASTERTRANSFORMER_CMAKE_DIR}/FasterTransformerTargets.cmake") +endif() + +set(FASTERTRANSFORMER_LIBRARIES transformer-shared) diff --git a/cmake/Modules/FindCUDNN.cmake b/cmake/Modules/FindCUDNN.cmake new file mode 100644 index 0000000000000000000000000000000000000000..7e7fc0c9391e661e14c5c4d9210abeb04be94dda --- /dev/null +++ b/cmake/Modules/FindCUDNN.cmake @@ -0,0 +1,51 @@ +# taken from https://github.com/pytorch/pytorch/blob/master/cmake/Modules_CUDA_fix/FindCUDNN.cmake +# Find the CUDNN libraries +# +# The following variables are optionally searched for defaults +# CUDNN_ROOT: Base directory where CUDNN is found +# CUDNN_INCLUDE_DIR: Directory where CUDNN header is searched for +# CUDNN_LIBRARY: Directory where CUDNN library is searched for +# CUDNN_STATIC: Are we looking for a static library? (default: no) +# +# The following are set after configuration is done: +# CUDNN_FOUND +# CUDNN_INCLUDE_PATH +# CUDNN_LIBRARY_PATH +# + +include(FindPackageHandleStandardArgs) + +set(CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} CACHE PATH "Folder containing NVIDIA cuDNN") +if (DEFINED $ENV{CUDNN_ROOT_DIR}) + message(WARNING "CUDNN_ROOT_DIR is deprecated. Please set CUDNN_ROOT instead.") +endif() +list(APPEND CUDNN_ROOT $ENV{CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) + +# Compatible layer for CMake <3.12. CUDNN_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${CUDNN_ROOT}) + +set(CUDNN_INCLUDE_DIR $ENV{CUDNN_INCLUDE_DIR} CACHE PATH "Folder containing NVIDIA cuDNN header files") + +find_path(CUDNN_INCLUDE_PATH cudnn.h + HINTS ${CUDNN_INCLUDE_DIR} + PATH_SUFFIXES cuda/include cuda include) + +option(CUDNN_STATIC "Look for static CUDNN" OFF) +if (CUDNN_STATIC) + set(CUDNN_LIBNAME "libcudnn_static.a") +else() + set(CUDNN_LIBNAME "cudnn") +endif() + +set(CUDNN_LIBRARY $ENV{CUDNN_LIBRARY} CACHE PATH "Path to the cudnn library file (e.g., libcudnn.so)") +if (CUDNN_LIBRARY MATCHES ".*cudnn_static.a" AND NOT CUDNN_STATIC) + message(WARNING "CUDNN_LIBRARY points to a static library (${CUDNN_LIBRARY}) but CUDNN_STATIC is OFF.") +endif() + +find_library(CUDNN_LIBRARY_PATH ${CUDNN_LIBNAME} + PATHS ${CUDNN_LIBRARY} + PATH_SUFFIXES lib lib64 cuda/lib cuda/lib64 lib/x64) + +find_package_handle_standard_args(CUDNN DEFAULT_MSG CUDNN_LIBRARY_PATH CUDNN_INCLUDE_PATH) + +mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY) diff --git a/cmake/Modules/FindNCCL.cmake b/cmake/Modules/FindNCCL.cmake new file mode 100644 index 0000000000000000000000000000000000000000..d2f2f8358af1df739f918c25d1a9405e7dd32979 --- /dev/null +++ b/cmake/Modules/FindNCCL.cmake @@ -0,0 +1,165 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. 
+# +# From PyTorch: +# +# Copyright (c) 2016- Facebook, Inc (Adam Paszke) +# Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +# Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +# Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +# Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +# Copyright (c) 2011-2013 NYU (Clement Farabet) +# Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +# Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +# Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) +# +# From Caffe2: +# +# Copyright (c) 2016-present, Facebook Inc. All rights reserved. +# +# All contributions by Facebook: +# Copyright (c) 2016 Facebook Inc. +# +# All contributions by Google: +# Copyright (c) 2015 Google Inc. +# All rights reserved. +# +# All contributions by Yangqing Jia: +# Copyright (c) 2015 Yangqing Jia +# All rights reserved. +# +# All contributions by Kakao Brain: +# Copyright 2019-2020 Kakao Brain +# +# All contributions from Caffe: +# Copyright(c) 2013, 2014, 2015, the respective contributors +# All rights reserved. +# +# All other contributions: +# Copyright(c) 2015, 2016 the respective contributors +# All rights reserved. +# +# Caffe2 uses a copyright model similar to Caffe: each contributor holds +# copyright over their contributions to Caffe2. The project versioning records +# all such contribution and copyright details. If a contributor wants to further +# mark their specific copyright on a particular contribution, they should +# indicate their copyright solely in the commit message of the change when it is +# committed. +# +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America +# and IDIAP Research Institute nor the names of its contributors may be +# used to endorse or promote products derived from this software without +# specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
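+# Example of pointing this module at a non-default NCCL installation
+# (a sketch; the directories are placeholders for an actual NCCL install,
+# and NCCL is only searched for when BUILD_MULTI_GPU is ON):
+#   NCCL_INCLUDE_DIR=/opt/nccl/include NCCL_LIB_DIR=/opt/nccl/lib \
+#     cmake -DBUILD_MULTI_GPU=ON ..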
+# +# Find the nccl libraries +# +# The following variables are optionally searched for defaults +# NCCL_ROOT: Base directory where all NCCL components are foundHong Xu, 1 year ago: • Let CMake handle NCCL detection instead of ou… +# NCCL_INCLUDE_DIR: Directory where NCCL header is foundPieter Noordhuis, 3 years ago: • Bump gloo +# NCCL_LIB_DIR: Directory where NCCL library is found +# +# The following are set after configuration is done: +# NCCL_FOUND +# NCCL_INCLUDE_DIRS +# NCCL_LIBRARIES +# +# The path hints include CUDA_TOOLKIT_ROOT_DIR seeing as some folks +# install NCCL in the same location as the CUDA toolkit. +# See https://github.com/caffe2/caffe2/issues/1601 + +set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers") +set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries") +set(NCCL_VERSION $ENV{NCCL_VERSION} CACHE STRING "Version of NCCL to build with") + +if ($ENV{NCCL_ROOT_DIR}) + message(WARNING "NCCL_ROOT_DIR is deprecated. Please set NCCL_ROOT instead.") +endif() +list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) +# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT}) + +find_path(NCCL_INCLUDE_DIRS + NAMES nccl.h + HINTS ${NCCL_INCLUDE_DIR}) + +if (USE_STATIC_NCCL) + MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.") + SET(NCCL_LIBNAME "nccl_static") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +else() + SET(NCCL_LIBNAME "nccl") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +endif() + +find_library(NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES) + +if(NCCL_FOUND) # obtaining NCCL version and some sanity checks + set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") + message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...") + set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) + list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS}) + include(CheckCXXSymbolExists) + check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) + + if (NCCL_VERSION_DEFINED) + set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") + file(WRITE ${file} " + #include + #include + int main() + { + std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl; + int x; + ncclGetVersion(&x); + return x == NCCL_VERSION_CODE; + } +") + try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} + RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}" + LINK_LIBRARIES ${NCCL_LIBRARIES}) + if (NOT NCCL_VERSION_MATCHED) + message(FATAL_ERROR "Found NCCL header version and library version do not match! 
\ +(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") + endif() + message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") + else() + # message(STATUS "NCCL version < 2.3.5-5") + endif () + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/cmake/TritonFasterTransformerBackendConfig.cmake.in b/cmake/TritonFasterTransformerBackendConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..61a4a5489a80d9d571cf1cddbee4840f70228e13 --- /dev/null +++ b/cmake/TritonFasterTransformerBackendConfig.cmake.in @@ -0,0 +1,39 @@ +# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +include(CMakeFindDependencyMacro) + +get_filename_component( + TRITONPYTORCHBACKEND_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH +) + +list(APPEND CMAKE_MODULE_PATH ${TRITONPYTORCHBACKEND_CMAKE_DIR}) + +if(NOT TARGET TritonPyTorchBackend::triton-pytorch-backend) + include("${TRITONPYTORCHBACKEND_CMAKE_DIR}/TritonPyTorchBackendTargets.cmake") +endif() + +set(TRITONPYTORCHBACKEND_LIBRARIES TritonPyTorchBackend::triton-pytorch-backend) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d2fd51ada3dd216e9c2facb69551131a5d90a91d --- /dev/null +++ b/examples/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(cpp) \ No newline at end of file diff --git a/examples/__init__.py b/examples/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ddbf5e9c21a53a000fcfdd50aa6a9670093e24a7 --- /dev/null +++ b/examples/cpp/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(llama) \ No newline at end of file diff --git a/examples/cpp/llama/CMakeLists.txt b/examples/cpp/llama/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d0f0dae55690aecb1827d651a885433cee5ad4fb --- /dev/null +++ b/examples/cpp/llama/CMakeLists.txt @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +add_executable(llama_triton_example llama_triton_example.cc) +target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart + LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils + nvtx_utils word_list glog) \ No newline at end of file diff --git a/examples/cpp/llama/generate_gemm_config.py b/examples/cpp/llama/generate_gemm_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e538e6d40f46796ffd008a4cd69111e39b601b4f --- /dev/null +++ b/examples/cpp/llama/generate_gemm_config.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import subprocess +import fire + + +def main(head_num: int = 80, + size_per_head: int = 128, + vocab_size: int = 65632, + inter_size: int = 27392, + tensor_para_size: int = 8, + max_batch_size: int = 64): + for bsz in range(1, max_batch_size + 1): + subprocess.call( + f'bin/gpt_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size} {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}', + shell=True) + + +if __name__ == '__main__': + fire.Fire(main) diff --git a/examples/cpp/llama/llama_ckpt_convert.py b/examples/cpp/llama/llama_ckpt_convert.py new file mode 100644 index 0000000000000000000000000000000000000000..a3ae4d1d0060d436be3534176f2d26ff2d8ab2d1 --- /dev/null +++ b/examples/cpp/llama/llama_ckpt_convert.py @@ -0,0 +1,210 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
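+# Example invocations (a sketch; paths are placeholders, and argument parsing
+# is handled by `fire.Fire(main)` at the bottom of this file):
+#   python llama_ckpt_convert.py fb /path/to/llama_ckpt_dir workspace/weights --n_inference 8
+#   python llama_ckpt_convert.py gptq /path/to/hf_model_dir workspace/weights --n_inference 8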
+ +import torch +import fire +import os.path as osp +from os import makedirs +from pathlib import Path +import safetensors +from typing import List +from tqdm import tqdm + + +def import_fb(ckpt_dir: str): + checkpoints = [] + for pattern in ['*.pth', '*.pt']: + checkpoints += sorted(Path(ckpt_dir).glob(pattern)) + print(checkpoints) + n_ckpt = len(checkpoints) + model_params = {} + + def get_param(name, size): + print(name, size) + if name not in model_params: + model_params[name] = torch.zeros( + size, dtype=torch.float16, device='cpu') + return model_params[name] + for i, ckpt_path in enumerate(checkpoints): + ckpt = torch.load(ckpt_path, map_location='cpu') + for param_name, param_data in ckpt.items(): + key = param_name.split('.')[-2] + if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'output']: # column-parallel + size = param_data.size(0) + param = get_param( + param_name, [size * n_ckpt, param_data.size(1)]) + param.data[size * i: size * (i + 1), :] = param_data + elif key in ['w2', 'wo', 'tok_embeddings']: # row-parallel + size = param_data.size(-1) + param = get_param( + param_name, [param_data.size(0), size * n_ckpt]) + param.data[:, size * i: size * (i + 1)] = param_data + elif i == 0: + param = get_param(param_name, param_data.size()) + param.data = param_data + del ckpt + + for name, param in model_params.items(): + # transpose all weights as FasterTransformer is expecting column-major weights + # (output_dims, input_dims) -> (input_dims, output_dims) + key = name.split('.')[-2] + if key in ['w1', 'w3', 'wq', 'wk', 'wv', 'w2', 'wo']: + param.data = param.data.t() + + # concat qkv projection + for i in range(1000): + _qkv = [f'layers.{i}.attention.{k}.weight' for k in ['wq', 'wk', 'wv']] + try: + qkv = tuple(map(model_params.pop, _qkv)) + except KeyError: + break + qkv = torch.stack(qkv, dim=1) + model_params[f'layers.{i}.attention.w_qkv.weight'] = qkv + print(qkv.shape, qkv.dtype) + + return model_params + + +def permute(x: torch.Tensor): + SIZE_PER_HEAD = 128 + if x.shape[-1] > 1: # qweights + dim = x.shape[-1] + n_heads = dim // SIZE_PER_HEAD + return x.view(-1, n_heads, 2, dim // n_heads // 2).transpose(2, 3).reshape(-1, dim) + else: # scales, zeros + dim = x.shape[0] + n_heads = dim // SIZE_PER_HEAD + return x.view(n_heads, 2, dim // n_heads // 2, 1).transpose(1, 2).reshape(dim, 1) + + +def check_zero(x: torch.Tensor): + sum = x.flatten().sum().item() + assert sum == 0, str(sum) + + +def import_gptq(path: str): + model_params = {} + + _qweight = 'weight' + _suffixes = [_qweight] + n_split = 3 + if True: + _params = {} + for i in tqdm(range(0, n_split)): + filename = "pytorch_model-{:05d}-of-{:05d}.bin".format(i + 1, n_split) + _tmp = torch.load(osp.join(path, filename), map_location='cpu') + _params.update(_tmp) + # print('\n'.join(_params.keys())) + def get_tensor(name): + return _params[name] + def get_tensor_transposed(name): + return _params[name].t() + + # _qweight = 'qweight' + # _suffixes = [_qweight, 'bias', 'scales', 'zeros'] + # with safetensors.safe_open(path, framework='pt') as f: + # get_tensor = f.get_tensor + # # quantized weights are already in column major, no need to transpose + # get_tensor_transposed = get_tensor + for i in range(1000): + try: + # attention weights + _qkvo = [f'model.layers.{i}.self_attn.{t}_proj' for t in 'qkvo'] + for suffix in _suffixes: + q, k, v, o = map(get_tensor_transposed, map(('{}.' 
+ suffix).format, _qkvo)) + if suffix == 'bias': + check_zero(q), check_zero(k), check_zero(v), check_zero(o) + else: + # q, k has different layout for fb & hf, convert to fb's layout + q = permute(q) + k = permute(k) + if suffix == _qweight: # weight, qweight + # insert a dimension for splitting heads later + # qkv = torch.cat([q[:, None, :], k[:, None, :], v[:, None, :]], dim=1) + qkv = torch.stack((q, k, v), dim=1) + else: # scales, zeros + # qkv = torch.cat([q[None, :], k[None, :], v[None, :]], dim=0).squeeze(dim=-1) + qkv = torch.stack((q, k, v), dim=0).squeeze(dim=-1) + for k, v in [('w_qkv', qkv), ('wo', o)]: + model_params[f'layers.{i}.attention.{k}.{suffix}'] = v + # ffn weights + _w123 = [f'model.layers.{i}.mlp.{t}_proj' for t in ['gate', 'down', 'up']] + for suffix in _suffixes: + w1, w2, w3 = map(get_tensor_transposed, map(('{}.' + suffix).format, _w123)) + if suffix == 'bias': + check_zero(w1), check_zero(w2), check_zero(w3) + else: + if suffix in ['scales', 'zeros']: + w1, w2, w3 = map(lambda x: x.squeeze(dim=-1), [w1, w2, w3]) + for k, v in [('w1', w1), ('w2', w2), ('w3', w3)]: + model_params[f'layers.{i}.feed_forward.{k}.{suffix}'] = v + other = [('attention_norm.weight', 'input_layernorm.weight'), + ('ffn_norm.weight', 'post_attention_layernorm.weight')] + for ours, theirs in other: + model_params[f'layers.{i}.' + ours] = get_tensor(f'model.layers.{i}.' + theirs) + except safetensors.SafetensorError: + break + except KeyError: + break + print(i) + + other = [('tok_embeddings.weight', 'model.embed_tokens.weight'), + ('norm.weight', 'model.norm.weight'), + ('output.weight', 'lm_head.weight')] + for ours, theirs in other: + model_params[ours] = get_tensor(theirs) + + return model_params + + +def export(model_params: dict, out_dir: str, n_inference: int): + makedirs(out_dir, exist_ok=True) + + def save_bin(param: torch.Tensor, name): + print(name, param.shape) + if param.dtype in [torch.float, torch.bfloat16]: + param = param.half() + param.contiguous().numpy().tofile(osp.join(out_dir, name)) + + # reverse the spliting axes since the weights are transposed above + for param_name, param_data in model_params.items(): + split_dim = None + key, ext = param_name.split('.')[-2:] + copy = False + if key in ['w1', 'w3', 'w_qkv']: + split_dim = -1 + elif key in ['w2', 'wo']: + if ext in ['scales', 'zeros']: + copy = True + else: + split_dim = 0 + if split_dim is not None: + print(f'*** spliting {param_name}, shape={param_data.shape}, split_dim={split_dim}') + assert param_data.shape[split_dim] % n_inference == 0 + split_size = param_data.shape[split_dim] // n_inference + splits = torch.split(param_data, split_size, dim=split_dim) + for i, split in enumerate(splits): + prefix, ext = osp.splitext(param_name) + save_bin(split, f'{prefix}.{i}{ext}') + elif copy: + print(f'### copying {param_name}, shape={param_data.shape}') + copies = [param_data] * n_inference + for i, copy in enumerate(copies): + prefix, ext = osp.splitext(param_name) + save_bin(copy, f'{prefix}.{i}{ext}') + else: + save_bin(param_data, param_name) + + +def main(kind: str, input_path: str, out_dir: str, n_inference: int = 1): + if kind == 'fb': + model_params = import_fb(input_path) + elif kind == 'gptq': + model_params = import_gptq(input_path) + else: + raise RuntimeError(f'Unsupported kind: {kind}') + + export(model_params, out_dir, n_inference) + + +if __name__ == '__main__': + fire.Fire(main) \ No newline at end of file diff --git a/examples/cpp/llama/llama_config.ini b/examples/cpp/llama/llama_config.ini new file 
mode 100644 index 0000000000000000000000000000000000000000..09c662d8962d1b126b4b8ea12e879d71e0b999eb --- /dev/null +++ b/examples/cpp/llama/llama_config.ini @@ -0,0 +1,82 @@ +[ft_instance_hyperparameter] +data_type=fp16 +enable_custom_all_reduce=0 +pipeline_para_size=1 +tensor_para_size=8 +model_dir=/shared_data/chatpjlm-0/v0.2.3/fastertransformer/weights/ + + +[request] +request_batch_size=8 +request_output_len=2048 +beam_width=1 ; beam width for beam search +top_k=1 ; k value for top k sampling +top_p=0.0 ; p value for top p sampling +temperature=1.0 ; Use for sampling +repetition_penalty=1.00 ; Use for sampling +presence_penalty=0.0 ; Only one of repetition_penalty and presence_penalty are allowed. +len_penalty=0.0 +beam_search_diversity_rate=0.0 +; PJLM start/end ids +start_id=0 +end_id=1 + + +; --------------------- legacy params ------------------------- + +; LLaMA start/end ids +; start_id=1 +; end_id=2 + +[4999_llama] +head_num=80 +size_per_head=128 +vocab_size=65632 +num_layer=82 +rotary_embedding=128 +norm_eps=1e-5 +start_id=0 +end_id=1 +inter_size=27392 + +[llama_7B] +head_num=32 +size_per_head=128 +vocab_size=32000 +num_layer=32 +rotary_embedding=128 +start_id=1 +end_id=2 +inter_size=11008 + +[llama_13B] +head_num=40 +size_per_head=128 +vocab_size=32000 +num_layer=40 +rotary_embedding=128 +start_id=1 +end_id=2 +inter_size=13824 + +[llama_30B] +head_num=52 +size_per_head=128 +vocab_size=32000 +num_layer=60 +rotary_embedding=128 +start_id=1 +end_id=2 +inter_size=17920 + +[llama_65B] +head_num=64 +size_per_head=128 +vocab_size=32000 +num_layer=80 +rotary_embedding=128 +start_id=1 +end_id=2 +inter_size=22016 + + diff --git a/examples/cpp/llama/llama_triton_example.cc b/examples/cpp/llama/llama_triton_example.cc new file mode 100644 index 0000000000000000000000000000000000000000..6e698b040177243555cd6f47adbe7f6e19d18582 --- /dev/null +++ b/examples/cpp/llama/llama_triton_example.cc @@ -0,0 +1,584 @@ +/* + * Copyright (c) OpenMMLab. All rights reserved. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// Modified from https://github.com/NVIDIA/FasterTransformer/blob/main/examples/cpp/multi_gpu_gpt/multi_gpu_gpt_triton_example.cc + +#include "3rdparty/INIReader.h" +#include +#include + +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h" +#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h" +#include "src/fastertransformer/triton_backend/transformer_triton_backend.hpp" +#include "src/fastertransformer/utils/custom_ar_comm.h" +#include "src/fastertransformer/utils/mpi_utils.h" +#include "src/fastertransformer/utils/nccl_utils.h" +#include "src/fastertransformer/utils/nvtx_utils.h" +#include "src/fastertransformer/utils/word_list.h" + +namespace ft = fastertransformer; + +constexpr const bool kUSE_MPI = true; + +struct RequestParam { + int beam_width; + int request_output_len; + float beam_search_diversity_rate; + uint runtime_top_k; + float runtime_top_p; + float temperature; + float len_penalty; + float repetition_penalty; + float presence_penalty; + int min_length; + unsigned long long int random_seed; + int start_id; + int end_id; +}; + +std::vector>> +broadCastRequest(const std::vector& v_start_ids, + const std::vector& v_start_lengths, + const std::vector& v_bad_words, + const int node_id, + const int gpu_count, + const RequestParam param, + std::vector* pointer_record) +{ + // broadcast the request to all nodes, and copy "gpu_count" copies on + // different gpu + int size_1 = v_start_ids.size(); + int size_2 = v_start_lengths.size(); + int size_bad_words = v_bad_words.size(); + if (kUSE_MPI) { + ft::mpi::bcast(&size_1, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_2, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(&size_bad_words, 1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + } + + std::vector v_input_ids(size_1); + std::vector v_input_lengths(size_2); + std::vector v_input_bad_words(size_bad_words); + + if (node_id == 0) { + memcpy(v_input_ids.data(), v_start_ids.data(), size_1 * sizeof(int)); + memcpy(v_input_lengths.data(), v_start_lengths.data(), size_2 * sizeof(int)); + memcpy(v_input_bad_words.data(), v_bad_words.data(), size_bad_words * sizeof(int)); + } + if (kUSE_MPI) { + ft::mpi::barrier(); + } + + int request_batch_size = size_2; + int max_input_len = size_1 / size_2; + + std::cerr << "request_batch_size=" << request_batch_size << " max_input_len=" << max_input_len << "\n"; + + if (kUSE_MPI) { + ft::mpi::bcast(v_input_ids.data(), size_1, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_lengths.data(), size_2, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + ft::mpi::bcast(v_input_bad_words.data(), size_bad_words, ft::mpi::MPI_TYPE_INT, 0, ft::mpi::COMM_WORLD); + } + + std::vector>> request_list; + for (int device_id = 0; device_id < gpu_count; device_id++) { + ft::check_cuda_error(cudaSetDevice(device_id)); + + int* d_input_ids; + // int* d_input_lengths; + int* d_input_bad_words; + + if (max_input_len == 0) { + // unconditional case, no input ids, so do nothing. + d_input_ids = nullptr; + // d_input_lengths = nullptr; + max_input_len = 0; + } + else { + // conditional case. 
+ ft::deviceMalloc(&d_input_ids, size_1, false); + // ft::deviceMalloc(&d_input_lengths, size_2, false); + ft::cudaH2Dcpy(d_input_ids, v_input_ids.data(), size_1); + // ft::cudaH2Dcpy(d_input_lengths, v_input_lengths.data(), size_2); + } + + if (!v_input_bad_words.empty()) { + ft::deviceMalloc(&d_input_bad_words, size_bad_words, false); + ft::cudaH2Dcpy(d_input_bad_words, v_input_bad_words.data(), size_bad_words); + } + else { + d_input_bad_words = nullptr; + } + + uint32_t* request_output_len_ptr = (uint32_t*)malloc(request_batch_size * sizeof(uint32_t)); + int* input_lengths_ptr = (int*)malloc(request_batch_size * sizeof(int)); + for (int i = 0; i < request_batch_size; i++) { + request_output_len_ptr[i] = param.request_output_len; + input_lengths_ptr[i] = v_input_lengths[i]; + } + + int* start_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + int* end_ids_ptr = (int*)malloc(request_batch_size * sizeof(int)); + for (int i = 0; i < request_batch_size; i++) { + start_ids_ptr[i] = param.start_id; + end_ids_ptr[i] = param.end_id; + } + pointer_record->push_back(start_ids_ptr); + pointer_record->push_back(end_ids_ptr); + + request_list.push_back(std::shared_ptr>( + new std::unordered_map{ + {"input_ids", + triton::Tensor{triton::MEMORY_GPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size, (size_t)max_input_len}, + d_input_ids}}, + {"input_lengths", + triton::Tensor{triton::MEMORY_CPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + input_lengths_ptr}}, + {"request_output_len", + triton::Tensor{triton::MEMORY_CPU, + triton::TYPE_INT32, + std::vector{(size_t)request_batch_size}, + request_output_len_ptr}}, + {"bad_words_list", + triton::Tensor{ + triton::MEMORY_GPU, triton::TYPE_INT32, {2, v_input_bad_words.size() / 2}, d_input_bad_words}}, + {"start_id", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, start_ids_ptr}}, + {"end_id", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, {(size_t)request_batch_size}, end_ids_ptr}}})); + + int* beam_width_ptr = new int(param.beam_width); + pointer_record->push_back(beam_width_ptr); + request_list[device_id]->insert( + {"beam_width", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, beam_width_ptr}}); + if (param.beam_width > 1) { + float* beam_search_diversity_rate_ptr = new float(param.beam_search_diversity_rate); + pointer_record->push_back(beam_search_diversity_rate_ptr); + request_list[device_id]->insert( + {"beam_search_diversity_rate", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, beam_search_diversity_rate_ptr}}); + } + else { + if (param.runtime_top_p != 0.0f) { + float* runtime_top_p_ptr = new float(param.runtime_top_p); + pointer_record->push_back(runtime_top_p_ptr); + request_list[device_id]->insert( + {"runtime_top_p", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, runtime_top_p_ptr}}); + } + if (param.runtime_top_k != 0) { + uint* runtime_top_k_ptr = new uint(param.runtime_top_k); + pointer_record->push_back(runtime_top_k_ptr); + request_list[device_id]->insert( + {"runtime_top_k", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_UINT32, std::vector{1}, runtime_top_k_ptr}}); + } + } + float* temperature_ptr = new float(param.temperature); + pointer_record->push_back(temperature_ptr); + request_list[device_id]->insert( + {"temperature", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, temperature_ptr}}); + float* len_penalty_ptr = new 
float(param.len_penalty); + pointer_record->push_back(len_penalty_ptr); + request_list[device_id]->insert( + {"len_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, len_penalty_ptr}}); + if (param.repetition_penalty != 1.0f) { + float* repetition_penalty_ptr = new float(param.repetition_penalty); + pointer_record->push_back(repetition_penalty_ptr); + request_list[device_id]->insert( + {"repetition_penalty", + triton::Tensor{ + triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, repetition_penalty_ptr}}); + } + if (param.presence_penalty != 0.0f) { + float* presence_penalty_ptr = new float(param.presence_penalty); + pointer_record->push_back(presence_penalty_ptr); + request_list[device_id]->insert( + {"presence_penalty", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_FP32, std::vector{1}, presence_penalty_ptr}}); + } + int* min_length_ptr = new int(param.min_length); + pointer_record->push_back(min_length_ptr); + request_list[device_id]->insert( + {"min_length", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_INT32, std::vector{1}, min_length_ptr}}); + unsigned long long int* random_seed_ptr = new unsigned long long int(param.random_seed); + pointer_record->push_back(random_seed_ptr); + request_list[device_id]->insert( + {"random_seed", + triton::Tensor{triton::MEMORY_CPU, triton::TYPE_UINT64, std::vector{1}, random_seed_ptr}}); + + pointer_record->push_back(d_input_ids); + // pointer_record->push_back(d_input_lengths); + pointer_record->push_back(d_input_bad_words); + pointer_record->push_back(request_output_len_ptr); + pointer_record->push_back(input_lengths_ptr); + } + + return request_list; +} + +int read_start_ids(size_t batch_size, + std::vector* v_start_lengths, + std::vector* v_start_ids, + size_t& max_input_len, + const int end_id, + const int beam_width, + std::string file_name); + +std::vector>> +prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std::vector* pointer_record) +{ + INIReader reader = INIReader(ini_name); + if (reader.ParseError() < 0) { + std::cout << "[ERROR] Can't load '" << ini_name << "'\n"; + ft::FT_CHECK(false); + } + + const size_t request_batch_size = reader.GetInteger("request", "request_batch_size"); + std::cerr << "request_batch_size=" << request_batch_size << "\n"; + + const int start_id = reader.GetInteger("request", "start_id"); + const int end_id = reader.GetInteger("request", "end_id"); + + std::vector v_start_ids; + std::vector v_start_lengths; + + size_t max_input_len = 0; + read_start_ids(request_batch_size, + &v_start_lengths, + &v_start_ids, + max_input_len, + end_id, + 1, + "../examples/cpp/llama/start_ids.csv"); + // drop requests > request_batch_size + if (v_start_lengths.size() > request_batch_size) { + v_start_lengths.resize(request_batch_size); + v_start_ids.resize(request_batch_size * max_input_len); + } + std::cerr << "max_input_len=" << max_input_len << "\n"; + + std::vector v_bad_words; + // ft::read_word_list("../examples/cpp/llama/bad_words.csv", v_bad_words); + + RequestParam param; + param.beam_width = reader.GetInteger("request", "beam_width"); + param.request_output_len = reader.GetInteger("request", "request_output_len"); + param.beam_search_diversity_rate = reader.GetFloat("request", "beam_search_diversity_rate"); + param.runtime_top_k = reader.GetInteger("request", "top_k"); + param.runtime_top_p = reader.GetFloat("request", "top_p"); + param.temperature = reader.GetFloat("request", "temperature"); + param.len_penalty = reader.GetFloat("request", 
"len_penalty"); + param.repetition_penalty = reader.GetFloat("request", "repetition_penalty", 1.0f); + param.presence_penalty = reader.GetFloat("request", "presence_penalty", 0.0f); + param.min_length = reader.GetInteger("request", "min_length", 0); + param.random_seed = (unsigned long long int)0; + param.start_id = start_id; + param.end_id = end_id; + + auto request_list = + broadCastRequest(v_start_ids, v_start_lengths, v_bad_words, node_id, gpu_count, param, pointer_record); + return request_list; +} + +int threadCreateModelInstances(std::shared_ptr model, + std::vector>* model_instances, + const int device_id, + const int rank, + std::pair, std::vector> nccl_params, + std::shared_ptr custom_all_reduce_comm = nullptr) +{ + printf("[INFO] rank = %d \n", rank); + ft::check_cuda_error(cudaSetDevice(device_id)); + cudaStream_t stream; + ft::check_cuda_error(cudaStreamCreate(&stream)); + model->createSharedWeights(device_id, rank); + auto model_instance = model->createModelInstance(device_id, rank, stream, nccl_params, custom_all_reduce_comm); + model_instances->at(device_id) = std::move(model_instance); + printf("model instance %d is created \n", device_id); + ft::print_mem_usage(); + return 0; +} + +int threadForward(std::unique_ptr* model_instance, + std::shared_ptr> request, + std::shared_ptr>* output_tensors, + const int device_id, + ft::AbstractInstanceComm* comm) +{ + ft::check_cuda_error(cudaSetDevice(device_id)); + cudaDeviceSynchronize(); + *output_tensors = (*model_instance)->forward(request, comm); + cudaDeviceSynchronize(); + return 0; +} + +int main(int argc, char* argv[]) +{ + /* + Prepare the nccl ids, node id, device id and world size + by MPI or triton + */ + + int node_id = 0; + int node_num = 1; + + if (kUSE_MPI) { + ft::mpi::initialize(&argc, &argv); + node_id = ft::mpi::getCommWorldRank(); + node_num = ft::mpi::getCommWorldSize(); + } + + printf("node_id=%d node_num=%d\n", node_id, node_num); + + // Note: Only supports that all nodes have same gpu count + const int gpu_count = ft::getDeviceCount(); + const int world_size = node_num * gpu_count; + std::string ini_name = argc >= 2 ? 
std::string(argv[1]) : "../examples/cpp/llama/llama_config.ini"; + + // step 1: Create model + std::shared_ptr model = AbstractTransformerModel::createLlamaModel(ini_name); + int tensor_para_size = model->getTensorParaSize(); + int pipeline_para_size = model->getPipelineParaSize(); + printf( + "world_size=%d tensor_para_size=%d pipeline_para_size=%d\n", world_size, tensor_para_size, pipeline_para_size); + FT_CHECK_WITH_INFO(world_size == (tensor_para_size * pipeline_para_size), + "World Size != Tensor Parallel Size * Pipeline Parallel Size !"); + + std::cout << model->toString(); + + // step 2: Initialize the NCCL + std::pair, std::vector> nccl_comms = model->createNcclParams(node_id); + cudaDeviceSynchronize(); + + // Optional Step: create custom all reduce comm + // std::vector> + // custom_all_reduce_comms; model->createCustomComms(&custom_all_reduce_comms, + // world_size); + + // step 2.1 create instance comm + auto instance_comm = model->createInstanceComm(gpu_count); + + // step 3: Create model instances + std::vector> model_instances((size_t)gpu_count); + std::vector threads; + for (int device_id = 0; device_id < gpu_count; device_id++) { + const int rank = node_id * gpu_count + device_id; + threads.push_back( + std::thread(threadCreateModelInstances, model, &model_instances, device_id, rank, nccl_comms, nullptr)); + // custom_all_reduce_comms[rank])); + } + for (auto& t : threads) { + t.join(); + } + + // step 4: prepare request + std::vector pointer_record; // Used to prevent the pointers are + // release after leaving functions + std::vector>> request_list = + prepareRequest(ini_name, node_id, gpu_count, &pointer_record); + printf("[INFO] request is created \n"); + + // step 5: Forward + std::vector>> output_tensors_lists( + (size_t)gpu_count); + for (int i = 0; i < 1; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id, + instance_comm.get())); + } + for (auto& t : threads) { + t.join(); + } + } + printf("[INFO] forward is completed. \n"); + + const int* d_output_ids = (const int*)output_tensors_lists[0].get()->at("output_ids").data; + const int* d_seq_lens = (const int*)output_tensors_lists[0].get()->at("sequence_length").data; + const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0]; + const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1]; + const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2]; + // step 6: check results + if (node_id == 0) { + std::string fName = "out"; + auto outFile = std::ofstream(fName, std::ios::out); + if (!outFile.is_open()) { + printf("[WARNING] Cannot write results into output file %s \n", fName.c_str()); + } + else { + size_t outCount = batch_size * beam_width * seq_len; + // int* hBuf = new int[outCount]; + std::vector hBuf(outCount); + ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount); + std::vector seq_lens(batch_size); + ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size); + std::cout << "sequence length: "; + for (int i = 0; i < batch_size; ++i) { + std::cout << (i ? 
", " : "") << seq_lens[i]; + } + std::cout << "\n"; + { + std::cout << "Writing " << outCount << " elements\n"; + int zeroCount = 0; + for (size_t i = 0; i < outCount; i++) { + if (hBuf[i] == int(0)) + zeroCount++; + outFile << hBuf[i] << " "; + if ((i + 1) % (seq_len) == 0) + outFile << std::endl; + + if (i < 10) + printf("%5d ", hBuf[i]); + if ((i + 1) % (seq_len) == 0 && i < 10) + std::cout << std::endl; + } + std::cout << std::endl << "zeroCount = " << zeroCount << std::endl; + } + } + } + + if (kUSE_MPI) { + ft::mpi::barrier(); + } + cudaDeviceSynchronize(); + + if (1) { + // test time + struct timeval start, end; + gettimeofday(&start, NULL); + + const int ite = 1; + for (int i = 0; i < ite; i++) { + threads.clear(); + for (int device_id = 0; device_id < gpu_count; device_id++) { + threads.push_back(std::thread(threadForward, + &model_instances[device_id], + request_list[device_id], + &output_tensors_lists[device_id], + device_id, + instance_comm.get())); + } + for (auto& t : threads) { + t.join(); + } + } + + cudaDeviceSynchronize(); + if (kUSE_MPI) { + ft::mpi::barrier(); + } + + gettimeofday(&end, NULL); + + printf("[INFO] batch_size %d beam_width %d seq_len %d" + " FT-CPP-GPT-Triton-time %.2f ms\n", + batch_size, + beam_width, + seq_len, + ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite); + } + + if (kUSE_MPI) { + ft::mpi::finalize(); + } + return 0; +} + +int read_start_ids(size_t batch_size, + std::vector* v_start_lengths, + std::vector* v_start_ids, + size_t& max_input_len, + const int end_id, + const int beam_width, + std::string file_name) +{ + std::vector> tmp_start_ids; + std::vector tmp_start_lengths; + + std::ifstream start_id_file(file_name, std::ios::in); + int line_num = 0; + if (start_id_file.is_open()) { + std::string line; + while (std::getline(start_id_file, line)) { + std::stringstream lineStream(line); + std::string vals; + int i1 = 0; + std::vector tmp_vec; + while (std::getline(lineStream, vals, ',')) { + tmp_vec.push_back(std::stoi(vals)); + i1++; + } + tmp_start_ids.push_back(tmp_vec); + tmp_start_lengths.push_back(i1); + line_num++; + } + if (batch_size == 0) { + batch_size = line_num; + } + } + else { + printf("[WARNING] Cannot open the file '%s'. \n", file_name.c_str()); + max_input_len = 0; + return 0; + } + + max_input_len = tmp_start_lengths.data()[0]; + for (uint i = 1; i < (uint)tmp_start_lengths.size(); i++) { + max_input_len = max_input_len > tmp_start_lengths.data()[i] ? 
max_input_len : tmp_start_lengths.data()[i]; + } + + while ((int)tmp_start_lengths.size() < batch_size) { + std::vector padding_ids; + for (int i = 0; i < max_input_len; i++) { + padding_ids.push_back(end_id); + } + tmp_start_ids.push_back(padding_ids); + tmp_start_lengths.push_back(max_input_len); + } + + // Add padding + for (int i = 0; i < (int)tmp_start_ids.size(); i++) { + for (int j = (int)tmp_start_ids[i].size(); j < max_input_len; j++) { + tmp_start_ids[i].push_back(end_id); + } + } + + for (int i = 0; i < (int)tmp_start_ids.size(); i++) { + for (int b = 0; b < beam_width; b++) { + for (int j = 0; j < (int)tmp_start_ids[i].size(); j++) { + v_start_ids->push_back(tmp_start_ids[i][j]); + } + v_start_lengths->push_back(tmp_start_lengths[i]); + } + } + return batch_size; +} diff --git a/examples/cpp/llama/start_ids.csv b/examples/cpp/llama/start_ids.csv new file mode 100644 index 0000000000000000000000000000000000000000..1c5d7b09658b59fedc63bceb2a922a4a15663582 --- /dev/null +++ b/examples/cpp/llama/start_ids.csv @@ -0,0 +1,8 @@ +0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,44883,2282,32901,4220,46323,13,44975,45004,11130,32843,45004,35597 +0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,46088,46064,625,19880,46323,13,44975,45004,11130,32843,45004,35597 +0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,47335,56437,60468,46323,13,44975,45004,11130,32843,45004,35597 +0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,44883,2282,6828,3467,46323,13,44975,45004,11130,32843,45004,35597 
+0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,36589,3467,7849,299,7032,46323,13,44975,45004,11130,32843,45004,35597 +0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,44976,39798,6828,3467,46323,13,44975,45004,11130,32843,45004,35597 +0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,2795,977,9193,299,405,537,46323,13,44975,45004,11130,32843,45004,35597 +0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,45691,45926,45513,46641,47641,46285,6456,46323,13,44975,45004,11130,32843,45004,35597 \ No newline at end of file diff --git a/examples/cpp/llama/tokenizer.py b/examples/cpp/llama/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9187a95dfe5ece5847df1d184c5c45ade78c7871 --- /dev/null +++ b/examples/cpp/llama/tokenizer.py @@ -0,0 +1,57 @@ +from sentencepiece import SentencePieceProcessor +from typing import List +import fire +import sys + + +class Tokenizer: + def __init__(self, model_file: str): + self.model = SentencePieceProcessor(model_file=model_file) + self.vocab_size = self.model.vocab_size() + self.start_id = self.model.bos_id() + self.end_id = self.model.eos_id() + self.pad_id = self.model.pad_id() + print(f'vocab_size = {self.vocab_size}') + print(f'start_id = {self.start_id}') + print(f'end_id = {self.end_id}') + print(f'pad_id = {self.pad_id}') + + def encode(self, s: str): + return self.model.Encode(s, add_bos=True) + + def decode(self, t: List[int]): + return self.model.Decode(t) + + +def main(model_file: str = '/data/llama/model/tokenizer.model', + 
encode_file: str = None, decode_file: str = None): + tokenizer = Tokenizer(model_file) + if encode_file: + with open(encode_file, 'r') as f: + xs = tokenizer.encode(f.read()) + print(','.join(map(str, xs))) + elif decode_file: + with open(decode_file, 'r') as f: + ys = tokenizer.decode(f.read()) + print(ys) + else: + first = True + while True: + try: + s = input() + except EOFError: + break + if not first: + print('---------------------------------------------') + first = False + try: + xs = map(int, s.strip().split(' ')) + s = tokenizer.decode(list(xs)) + print(s) + except ValueError: + xs = tokenizer.encode(s) + print(' '.join(map(str, xs))) + + +if __name__ == '__main__': + fire.Fire(main) \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..5474bbe37140e6f73fc15260b93a47f541ffcb1b --- /dev/null +++ b/src/CMakeLists.txt @@ -0,0 +1,15 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(fastertransformer) \ No newline at end of file diff --git a/src/fastertransformer/CMakeLists.txt b/src/fastertransformer/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9baa5329a0320efc291b7a5939cc401c95f4845e --- /dev/null +++ b/src/fastertransformer/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +add_subdirectory(utils) +add_subdirectory(kernels) +add_subdirectory(layers) +add_subdirectory(models) +if(BUILD_PYT) + add_subdirectory(th_op) +endif() +add_subdirectory(triton_backend) \ No newline at end of file diff --git a/src/fastertransformer/kernels/CMakeLists.txt b/src/fastertransformer/kernels/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d718b1fdf17bf10e154e29a4aaadfa0f46860874 --- /dev/null +++ b/src/fastertransformer/kernels/CMakeLists.txt @@ -0,0 +1,89 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
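The tokenizer.py helper above is what produces rows like those in start_ids.csv: fire exposes the keyword arguments of main() as command-line flags, so a prompt file can be encoded with something along the lines of "python3 tokenizer.py --model_file /path/to/tokenizer.model --encode_file prompt.txt" (paths hypothetical), which prints the comma-separated ids, and a whitespace-separated id sequence pasted into the interactive mode is decoded back to text.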
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) + +add_library(ban_bad_words STATIC ban_bad_words.cu) +set_property(TARGET ban_bad_words PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET ban_bad_words PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(stop_criteria STATIC stop_criteria_kernels.cu) +set_property(TARGET stop_criteria PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET stop_criteria PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(activation_kernels STATIC activation_kernels.cu) +set_property(TARGET activation_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET activation_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(gen_relative_pos_bias STATIC gen_relative_pos_bias.cu) +set_property(TARGET gen_relative_pos_bias PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET gen_relative_pos_bias PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(gen_relative_pos_bias PUBLIC activation_kernels) + +add_library(logprob_kernels STATIC logprob_kernels.cu) +set_property(TARGET logprob_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET logprob_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(unfused_attention_kernels STATIC unfused_attention_kernels.cu) +set_property(TARGET unfused_attention_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET unfused_attention_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(bert_preprocess_kernels STATIC bert_preprocess_kernels.cu) +set_property(TARGET bert_preprocess_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET bert_preprocess_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +set(decoder_masked_multihead_attention_files + decoder_masked_multihead_attention.cu +) +file(GLOB decoder_masked_multihead_attention_files ${decoder_masked_multihead_attention_files} ./decoder_masked_multihead_attention/*.cu) +add_library(decoder_masked_multihead_attention STATIC ${decoder_masked_multihead_attention_files}) +set_property(TARGET decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(online_softmax_beamsearch_kernels STATIC online_softmax_beamsearch_kernels.cu) +set_property(TARGET online_softmax_beamsearch_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET online_softmax_beamsearch_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(decoding_kernels STATIC decoding_kernels.cu) +set_property(TARGET decoding_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET decoding_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(gpt_kernels STATIC gpt_kernels.cu) +set_property(TARGET gpt_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET gpt_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(beam_search_penalty_kernels STATIC beam_search_penalty_kernels.cu) +set_property(TARGET beam_search_penalty_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET beam_search_penalty_kernels 
PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +target_link_libraries(beam_search_penalty_kernels PRIVATE cuda_utils) + +add_library(beam_search_topk_kernels STATIC beam_search_topk_kernels.cu) +set_property(TARGET beam_search_topk_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET beam_search_topk_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(sampling_topk_kernels STATIC sampling_topk_kernels.cu) +set_property(TARGET sampling_topk_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET sampling_topk_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(sampling_topp_kernels STATIC sampling_topp_kernels.cu) +set_property(TARGET sampling_topp_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET sampling_topp_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(sampling_penalty_kernels STATIC sampling_penalty_kernels.cu) +set_property(TARGET sampling_penalty_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET sampling_penalty_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) + +add_library(custom_ar_kernels STATIC custom_ar_kernels.cu) +set_property(TARGET custom_ar_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET custom_ar_kernels PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/fastertransformer/kernels/activation_kernels.cu b/src/fastertransformer/kernels/activation_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..aa1cd7b10d4d9b8fdd053876ecf2027e8b4b6651 --- /dev/null +++ b/src/fastertransformer/kernels/activation_kernels.cu @@ -0,0 +1,658 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/kernels/activation_kernels.h" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/memory_utils.h" + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! 
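+// The activation functors below operate elementwise. GeluActivation uses the tanh
+// approximation gelu(x) ~= 0.5 * x * (1 + tanh(0.7978845608 * (x + 0.044715 * x^3)))
+// (tanh_opt lowers to the tanh.approx PTX instruction on sm_75+ with CUDA 11+), and
+// SiluActivation computes x * sigmoid(x). Host-side reference versions, given here only
+// as a sketch for sanity-checking the kernels, could look like:
+//   float gelu_ref(float x) { return 0.5f * x * (1.f + tanhf(0.7978845608f * (x + 0.044715f * x * x * x))); }
+//   float silu_ref(float x) { return x / (1.f + expf(-x)); }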
+#endif + +namespace fastertransformer { + +/* Gelu Activation */ + +__forceinline__ __device__ float copysignf_pos(float a, float b) +{ + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__inline__ __device__ float tanh_opt(float x) +{ +#if (__CUDA_ARCH__ >= 750 && CUDART_VERSION >= 11000) + float r; + asm("tanh.approx.f32 %0,%1; \n\t" : "=f"(r) : "f"(x)); + return r; +#else + const float exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#endif +} + +template +struct GeluActivation { + using return_type = T; + static __device__ __forceinline__ T apply(const T& val) + { + const float cdf = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (val + 0.044715f * val * val * val)))); + return val * cdf; + } +}; + +template<> +struct GeluActivation { + using return_type = half2; + static __device__ __forceinline__ half2 apply(const half2& val) + { + half2 val_pow3 = __hmul2(val, __hmul2(val, val)); + float2 tmp_pow = __half22float2(val_pow3); + float2 tmp = __half22float2(val); + + tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return __hmul2(val, __float22half2_rn(tmp)); + } +}; + +#ifdef ENABLE_BF16 +template<> +struct GeluActivation<__nv_bfloat162> { + using return_type = __nv_bfloat162; + static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val) + { + __nv_bfloat162 val_pow3 = bf16hmul2(val, bf16hmul2(val, val)); + float2 tmp_pow = bf1622float2(val_pow3); + float2 tmp = bf1622float2(val); + + tmp.x = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.x + 0.044715f * tmp_pow.x)))); + tmp.y = 0.5f * (1.0f + tanh_opt((0.7978845608028654f * (tmp.y + 0.044715f * tmp_pow.y)))); + return bf16hmul2(val, __floats2bfloat162_rn(tmp.x, tmp.y)); + } +}; +#endif + +/* Relu Activation */ + +template +struct ReluActivation { + using return_type = T; + static __device__ __forceinline__ T apply(const T& val) + { + return val > static_cast(0.0f) ? val : static_cast(0.0f); + } +}; + +template<> +struct ReluActivation { + using return_type = half2; + static __device__ __forceinline__ half2 apply(const half2& val) + { + const half zero_half = static_cast(0.0f); + return make_half2(val.x > zero_half ? val.x : zero_half, val.y > zero_half ? val.y : zero_half); + } +}; + +#ifdef ENABLE_BF16 +template<> +struct ReluActivation<__nv_bfloat162> { + using return_type = __nv_bfloat162; + static __device__ __forceinline__ __nv_bfloat162 apply(const __nv_bfloat162& val) + { + const __nv_bfloat16 zero_bf16 = static_cast<__nv_bfloat16>(0.0f); + return make_bfloat162(val.x > zero_bf16 ? val.x : zero_bf16, val.y > zero_bf16 ? 
val.y : zero_bf16); + } +}; +#endif + +/* Silu Activation */ + +template +struct SiluActivation { + using return_type = T; + static __device__ __forceinline__ T apply(const T& val) + { + return (T)((float)val / (1.0f + __expf((float)-val))); + } +}; + +template<> +struct SiluActivation { + using return_type = float2; + static __device__ __forceinline__ float2 apply(const half2& val) + { + return make_float2(SiluActivation::apply(val.x), SiluActivation::apply(val.y)); + } +}; + +#ifdef ENABLE_BF16 +template<> +struct SiluActivation<__nv_bfloat162> { + using return_type = float2; + static __device__ __forceinline__ float2 apply(const __nv_bfloat162& val) + { + return make_float2(SiluActivation::apply(val.x), SiluActivation::apply(val.y)); + } +}; +#endif // ENABLE_BF16 + +/* Identity Activation (= no activation) */ + +template +struct IdentityActivation { + using return_type = T; + static __device__ __forceinline__ T apply(const T& val) + { + return val; + } +}; + +// clang-format off +template class Activation, typename T, typename BT> +__global__ void generic_activation(T* out, + const BT* __restrict bias, + const T* __restrict gated_weights, + const BT* __restrict gated_bias, + const int* __restrict ia3_tasks, + const T* __restrict ia3_weights, + const int int8_mode, + const float* __restrict activation_in, + const float* __restrict activation_out, + const int* __restrict padding_offset, + const int seq_len, + int m, + int n) +{ + constexpr size_t packed_elems = num_elems::value; + + const bool with_bias = bias != nullptr; + const bool with_gate = gated_weights != nullptr; + // const bool with_ia3 = ia3_tasks != nullptr; + + using Act_T = typename Activation::return_type; + using Float_T = typename packed_as::type; + using Packed_Int8_t = typename packed_as::type; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + T val; + if (int8_mode == 2) { + // val = cuda_cast(cuda_cast(reinterpret_cast(out)[id]) * activation_in[0]); + } + else { + val = out[id]; + } + + T gated_val; + if (with_gate) { + gated_val = gated_weights[id]; + } + + // if (with_bias) { + // const T reg_bias = static_cast(bias[id % n]); + // val = val + reg_bias; + + // if (with_gate) { + // const T reg_gated_bias = static_cast(gated_bias[id % n]); + // gated_val = gated_val + reg_gated_bias; + // } + // } + + if (with_gate) { + val = cuda_cast(Activation::apply(val) * cuda_cast(gated_val)); + } + else { + // val = cuda_cast(Activation::apply(val)); + } + + // if (with_ia3) { + // const int word_id = id / n; + // const int offset = padding_offset == nullptr ? 
0 : padding_offset[word_id]; + // const int batch_id = (word_id + offset) / seq_len; + // const int task = ia3_tasks[batch_id]; + // val = val * ia3_weights[task * n + (id % n)]; + // } + + if (int8_mode != 2) { + out[id] = val; + } + else { + // reinterpret_cast(out)[id] = + // cuda_cast(cuda_cast(val) * activation_out[0]); + } + } +} +// clang-format on + +template class Activation, typename T, typename BT> +void invokeGenericActivation(T* out, + const BT* bias, + const T* gated_weights, + const BT* gated_bias, + const int* ia3_tasks, + const T* ia3_weights, + const int m, + const int n, + const int int8_mode, + const float* activation_in, + const float* activation_out, + const int* padding_offset, + const int seq_len, + cudaStream_t stream) +{ + FT_LOG_DEBUG(__PRETTY_FUNCTION__); + FT_LOG_DEBUG("invokeGenericActivation %d %d %d", m, n, seq_len); + using PT = typename packed_type::type; + constexpr int packed_elems = num_elems::value; + using PBT = typename packed_as::type; + + const int n_threads = 512; + + dim3 block, grid; + if (n / 4 / packed_elems <= n_threads) { + block.x = n / 4 / packed_elems; + grid.x = m; + } + else { + block.x = n_threads; + grid.x = ceil(m * n / double(n_threads)); + } + FT_LOG_DEBUG("%d %d", grid.x, block.x); + sync_check_cuda_error(); + generic_activation<<>>(reinterpret_cast(out), + reinterpret_cast(bias), + reinterpret_cast(gated_weights), + reinterpret_cast(gated_bias), + ia3_tasks, + reinterpret_cast(ia3_weights), + int8_mode, + activation_in, + activation_out, + padding_offset, + seq_len, + m, + n / packed_elems); + sync_check_cuda_error(); +} + +#define INSTANTIATE_GENERIC_ACTIVATION(Activation, T, BT) \ + template void invokeGenericActivation(T * out, \ + const BT* bias, \ + const T* gated_weights, \ + const BT* gated_bias, \ + const int* ia3_tasks, \ + const T* ia3_weights, \ + const int m, \ + const int n, \ + const int int8_mode, \ + const float* activation_in, \ + const float* activation_out, \ + const int* padding_offset, \ + const int seq_len, \ + cudaStream_t stream); + +INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, float, float); +INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, half, half); +#ifdef ENABLE_BF16 +INSTANTIATE_GENERIC_ACTIVATION(GeluActivation, __nv_bfloat16, __nv_bfloat16); +#endif + +INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, float, float); +INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, half, half); +#ifdef ENABLE_BF16 +INSTANTIATE_GENERIC_ACTIVATION(ReluActivation, __nv_bfloat16, __nv_bfloat16); +#endif + +INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, float, float); +INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, half, half); +#ifdef ENABLE_BF16 +INSTANTIATE_GENERIC_ACTIVATION(SiluActivation, __nv_bfloat16, __nv_bfloat16); +#endif + +INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, float); +INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, half, half); +INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, half); +#ifdef ENABLE_BF16 +INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, __nv_bfloat16, __nv_bfloat16); +INSTANTIATE_GENERIC_ACTIVATION(IdentityActivation, float, __nv_bfloat16); +#endif +#undef INSTANCIATE_GENERIC_ACTIVATION + +template +__global__ void add_bias_tanh(T* out, const T* __restrict bias, int m, int n) +{ + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + T val = out[id]; + if (bias != nullptr) { + val = val + ldg(&bias[id % n]); + } + out[id] = tanhf(val); + } +} + +template<> +__global__ void add_bias_tanh(half* out, const half* 
__restrict bias, int m, int n) +{ + half2* out_ptr = (half2*)out; + const half2* bias_ptr = (half2*)bias; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + half2 val = out_ptr[id]; + if (bias != nullptr) { + val = val + __ldg(&bias_ptr[id % n]); + } + val.x = tanhf(val.x); + val.y = tanhf(val.y); + out_ptr[id] = val; + } +} + +#ifdef ENABLE_BF16 +template<> +__global__ void add_bias_tanh(__nv_bfloat16* out, const __nv_bfloat16* __restrict bias, int m, int n) +{ + __nv_bfloat162* out_ptr = (__nv_bfloat162*)out; + const __nv_bfloat162* bias_ptr = (__nv_bfloat162*)bias; + + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { + __nv_bfloat162 val = out_ptr[id]; + if (bias != nullptr) { + val = bf16hadd2(val, ldg(&bias_ptr[id % n])); + } + val.x = tanhf(val.x); + val.y = tanhf(val.y); + out_ptr[id] = val; + } +} +#endif + +template +void invokeAddBiasTanh(T* out, const T* bias, const int m, const int n, cudaStream_t stream) +{ + const int data_type_factor = 4 / sizeof(T); // 1 for fp32, 2 for fp16 and bf16 + dim3 block, grid; + if (n / 4 / data_type_factor <= 1024) { + block.x = n / 4 / data_type_factor; + grid.x = m; + } + else { + block.x = 1024; + grid.x = ceil(m * n / 1024.); + } + add_bias_tanh<<>>(out, bias, m, n / data_type_factor); +} + +template void invokeAddBiasTanh(float* out, const float* bias, const int m, const int n, cudaStream_t stream); +template void invokeAddBiasTanh(half* out, const half* bias, const int m, const int n, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void +invokeAddBiasTanh(__nv_bfloat16* out, const __nv_bfloat16* bias, const int m, const int n, cudaStream_t stream); +#endif + +template +__global__ void addBiasGeluV2(T2* out, + const T2* __restrict bias, + const int* ia3_tasks, + const T2* ia3_weights, + const int size, + const int* padding_offset, + const int seq_len) +{ + const bool with_ia3 = ia3_tasks != nullptr; + for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) { + T2 val = out[id]; + if (bias != nullptr) { + T2 reg_bias = ldg(&bias[id % N]); + val = hadd2(val, reg_bias); + } + val = GeluActivation::apply(val); + if (with_ia3) { + const int word_id = id / N; + const int offset = padding_offset == nullptr ? 0 : padding_offset[word_id]; + const int batch_id = (word_id + offset) / seq_len; + const int task = ia3_tasks[batch_id]; + val = val * ia3_weights[task * N + (id % N)]; + } + out[id] = val; + } +} + +template +__global__ void addBiasGeluV3(T2* out, + const T2* __restrict bias, + const int* ia3_tasks, + const T2* ia3_weights, + const int size, + const int* padding_offset, + const int seq_len) +{ + const bool with_ia3 = ia3_tasks != nullptr; + T2 buffer[ELEMENT_PER_ROUND]; + T2 tmp_bias[ELEMENT_PER_ROUND]; + for (int id = blockIdx.x * blockDim.x * ELEMENT_PER_ROUND + threadIdx.x * ELEMENT_PER_ROUND; id < size; + id += blockDim.x * gridDim.x * ELEMENT_PER_ROUND) { +#pragma unroll + for (int i = 0; i < ELEMENT_PER_ROUND; i++) { + buffer[i] = out[id + i]; + if (bias != nullptr) { + tmp_bias[i] = ldg(&bias[(id + i) % N]); + } + } +#pragma unroll + for (int i = 0; i < ELEMENT_PER_ROUND; i++) { + if (bias != nullptr) { + buffer[i] = hadd2(buffer[i], tmp_bias[i]); + } + buffer[i] = GeluActivation::apply(buffer[i]); + if (with_ia3) { + const int word_id = (id + i) / N; + const int offset = padding_offset == nullptr ? 
0 : padding_offset[word_id]; + const int batch_id = (word_id + offset) / seq_len; + const int task = ia3_tasks[batch_id]; + buffer[i] = buffer[i] * ia3_weights[task * N + ((id + i) % N)]; + } + out[id + i] = buffer[i]; + } + } +} + +#define ADD_BIAS_GELU(HALF_N, ELEMENT_PER_ROUND) \ + case HALF_N: \ + if (ELEMENT_PER_ROUND > 1) { \ + grid.x = grid.x / ELEMENT_PER_ROUND; \ + addBiasGeluV3<<>>( \ + (T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n, padding_offset, seq_len); \ + } \ + else { \ + addBiasGeluV2<<>>( \ + (T2*)out, (const T2*)bias, ia3_tasks, (T2*)ia3_weights, m * half_n, padding_offset, seq_len); \ + } \ + break; + +template +void invokeAddBiasGeluV2(T* out, + const T* bias, + const int* ia3_tasks, + const T* ia3_weights, + const int* padding_offset, + const int seq_len, + const int m, + const int n, + cudaStream_t stream) +{ + if (n % 2 == 0 && sizeof(T) == 2) { + const int half_n = n / 2; + dim3 block, grid; + block.x = std::min(half_n, 512); + grid.x = (m * half_n + (block.x - 1)) / block.x; + using T2 = typename TypeConverter::Type; + + if (grid.x >= 512) { + switch (half_n) { + ADD_BIAS_GELU(256, 1) + ADD_BIAS_GELU(512, 1) + ADD_BIAS_GELU(1024, 1) + ADD_BIAS_GELU(1536, 1) + ADD_BIAS_GELU(2048, 1) + ADD_BIAS_GELU(4096, 2) + ADD_BIAS_GELU(8192, 2) + ADD_BIAS_GELU(16384, 2) + ADD_BIAS_GELU(24576, 2) + ADD_BIAS_GELU(40960, 4) + default: + invokeGenericActivation(out, + bias, + (T*)nullptr, + (T*)nullptr, + ia3_tasks, + ia3_weights, + m, + n, + 0, + (float*)nullptr, + (float*)nullptr, + padding_offset, + seq_len, + stream); + break; + } + } + else { + switch (half_n) { + ADD_BIAS_GELU(256, 1) + ADD_BIAS_GELU(512, 1) + ADD_BIAS_GELU(1024, 1) + ADD_BIAS_GELU(1536, 1) + ADD_BIAS_GELU(2048, 1) + ADD_BIAS_GELU(4096, 1) + ADD_BIAS_GELU(8192, 2) + ADD_BIAS_GELU(16384, 2) + ADD_BIAS_GELU(24576, 2) + ADD_BIAS_GELU(40960, 2) + default: + invokeGenericActivation(out, + bias, + (T*)nullptr, + (T*)nullptr, + ia3_tasks, + ia3_weights, + m, + n, + 0, + (float*)nullptr, + (float*)nullptr, + padding_offset, + seq_len, + stream); + break; + } + } + } + else { + invokeGenericActivation(out, + bias, + (T*)nullptr, + (T*)nullptr, + ia3_tasks, + ia3_weights, + m, + n, + 0, + (float*)nullptr, + (float*)nullptr, + padding_offset, + seq_len, + stream); + } +} + +#undef ADD_BIAS_GELU + +template void invokeAddBiasGeluV2(float* out, + const float* bias, + const int* ia3_tasks, + const float* ia3_weights, + const int* padding_offset, + const int seq_len, + const int m, + const int n, + cudaStream_t stream); +template void invokeAddBiasGeluV2(half* out, + const half* bias, + const int* ia3_tasks, + const half* ia3_weights, + const int* padding_offset, + const int seq_len, + const int m, + const int n, + cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeAddBiasGeluV2(__nv_bfloat16* out, + const __nv_bfloat16* bias, + const int* ia3_tasks, + const __nv_bfloat16* ia3_weights, + const int* padding_offset, + const int seq_len, + const int m, + const int n, + cudaStream_t stream); +#endif // ENABLE_BF16 + +template +__global__ void sigmoid_kernel(T* data, const int size, const float scale) +{ + const int index = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + threadIdx.x; + if (index < size) { + float val = cuda_cast(data[index]); + val = 1.0f / (1.0f + exp(-val)) * scale; + data[index] = T(val); + } +} + +template<> +__global__ void sigmoid_kernel(half2* data, const int size, const float scale) +{ + const int index = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x + 
threadIdx.x; + if (index < size / 2) { + half2 val = data[index]; + float2 val_float2 = cuda_cast(val); + val_float2.x = 1.0f / (1.0f + exp(-val_float2.x)) * scale; + val_float2.y = 1.0f / (1.0f + exp(-val_float2.y)) * scale; + data[index] = cuda_cast(val_float2); + } +} + +template +void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream) +{ + if (std::is_same::value || (size % 2 != 0)) { + dim3 block(128); + dim3 grid((size + 127) / 128); + sigmoid_kernel<<>>(data, size, scale); + } + else { + dim3 block(128); + dim3 grid((size + 255) / 256); + sigmoid_kernel<<>>((half2*)data, size, scale); + } +} + +template void invokeSigmoid(float* data, const int size, const float scale, cudaStream_t stream); +template void invokeSigmoid(half* data, const int size, const float scale, cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/activation_kernels.h b/src/fastertransformer/kernels/activation_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..e4c561e483921708cf62f6997c970cbdbe4299f6 --- /dev/null +++ b/src/fastertransformer/kernels/activation_kernels.h @@ -0,0 +1,110 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
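For a LLaMA-style SwiGLU feed-forward block, the gated form of invokeGenericActivation declared in this header is the natural entry point: when gated_weights is non-null it computes out[i] = silu(out[i]) * gated_weights[i] in place. A minimal call-site sketch, assuming d_gate holds the gate projection, d_up the up projection (both [token_num, inter_size] device buffers) and stream is a valid cudaStream_t:

    invokeGenericActivation<SiluActivation, half, half>(
        d_gate, /*bias=*/nullptr, d_up, /*gated_bias=*/nullptr,
        /*ia3_tasks=*/nullptr, /*ia3_weights=*/nullptr,
        token_num, inter_size, /*int8_mode=*/0,
        /*activation_in=*/nullptr, /*activation_out=*/nullptr, stream);
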
+ */ + +#pragma once + +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +namespace fastertransformer { + +// clang-format off +template struct GeluActivation; +template struct ReluActivation; +template struct SiluActivation; +template struct IdentityActivation; +// clang-format on + +template class Activation, typename T, typename BT> +void invokeGenericActivation(T* out, + const BT* bias, + const T* gated_weights, + const BT* gated_bias, + const int* ia3_tasks, + const T* ia3_weights, + const int m, + const int n, + const int int8_mode, + const float* activation_in, + const float* activation_out, + const int* padding_offset, + const int seq_len, + cudaStream_t stream); + +template class Activation, typename T, typename BT> +void invokeGenericActivation(T* out, + const BT* bias, + const T* gated_weights, + const BT* gated_bias, + const int* ia3_tasks, + const T* ia3_weights, + const int m, + const int n, + const int int8_mode, + const float* activation_in, + const float* activation_out, + cudaStream_t stream) +{ + invokeGenericActivation(out, + bias, + gated_weights, + gated_bias, + ia3_tasks, + ia3_weights, + m, + n, + int8_mode, + activation_in, + activation_out, + (const int*)nullptr, + 0, + stream); +} + +template +void invokeAddBiasGeluV2(T* out, + const T* bias, + const int* ia3_tasks, + const T* ia3_weights, + const int* padding_offset, + const int seq_len, + const int m, + const int n, + cudaStream_t stream); + +template +void invokeAddBias(T* out, T const* bias, const int m, const int n, cudaStream_t stream) +{ + invokeGenericActivation( + out, bias, nullptr, nullptr, nullptr, nullptr, m, n, 0, nullptr, nullptr, stream); +} + +template +void invokeAddBiasGeluV2( + T* out, const T* bias, const int* ia3_tasks, const T* ia3_weights, const int m, const int n, cudaStream_t stream) +{ + invokeAddBiasGeluV2(out, bias, ia3_tasks, ia3_weights, nullptr, 0, m, n, stream); +} + +template +void invokeAddBiasTanh(T* out, const T* bias, const int m, const int n, cudaStream_t stream); + +template +void invokeSigmoid(T* data, const int size, const float scale, cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/ban_bad_words.cu b/src/fastertransformer/kernels/ban_bad_words.cu new file mode 100644 index 0000000000000000000000000000000000000000..e5fb77f004ff236f6c935b194b5e5be13ddb19c5 --- /dev/null +++ b/src/fastertransformer/kernels/ban_bad_words.cu @@ -0,0 +1,164 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
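The kernel below bans an n-gram by pushing the logit of its final token to -infinity whenever the previously generated tokens match the rest of the sequence (single-token entries are banned unconditionally). The bad_words buffer packs all banned sequences into two equal-length rows: row 0 holds the concatenated token ids and row 1 the cumulative end offsets padded with -1, matching the [2, num_ids] "bad_words_list" tensor built in the Triton example. A hypothetical list banning {5, 6, 7} and {42} would be laid out as:

    // row 0 (ids)     :  5   6   7  42
    // row 1 (offsets) :  3   4  -1  -1      // bad_words_len = 4
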
+ */ + +#include "src/fastertransformer/kernels/ban_bad_words.h" +#include "src/fastertransformer/utils/cuda_utils.h" + +namespace fastertransformer { + +template +__global__ void ban_bad_words(T* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int beam_width, + const int* bad_words, + size_t bad_words_len, + bool share_words, + int id_offset, + int vocab_size_padded, + size_t step) +{ + const int id = blockIdx.x * blockDim.x + threadIdx.x; + const int batch_idx = blockIdx.y / beam_width; + const int beam_idx = blockIdx.y % beam_width; + + const int* base_bad_words = share_words ? bad_words : bad_words + batch_idx * 2 * bad_words_len; + const int* base_bad_words_offsets = base_bad_words + bad_words_len; + + if (id >= bad_words_len || base_bad_words_offsets[id] < 0) { + return; + } + + const int item_end = base_bad_words_offsets[id]; + const int item_start = (id > 0) ? base_bad_words_offsets[id - 1] : 0; + const int item_size = item_end - item_start; + + /* The single-token case unconditionally bans the token */ + bool should_ban = item_size == 1; + + /* Multi-token case and enough previously generated tokens to look for a match */ + if (item_size > 1 && step >= item_size - 1) { + should_ban = true; + int parent_id = beam_idx; + const bool gather_beam = beam_width > 1; + + for (int token_idx = item_size - 2; token_idx >= 0; token_idx--) { + const int previous_token = output_ids_buf[(step - (item_size - 1) + token_idx) * batch_size * beam_width + + id_offset + batch_idx * beam_width + parent_id]; + + if (previous_token != base_bad_words[item_start + token_idx]) { + should_ban = false; + break; + } + if (gather_beam) { + parent_id = parent_ids_buf[(step - (item_size - 1) + token_idx) * beam_width * batch_size + id_offset + + batch_idx * beam_width + parent_id]; + + if (parent_id < 0 || parent_id >= beam_width) { + should_ban = false; + break; + } + } + } + } + + if (should_ban) { + int banned_token = base_bad_words[item_end - 1]; + if (0 < banned_token && banned_token < vocab_size_padded) { + logits[batch_idx * beam_width * vocab_size_padded + beam_idx * vocab_size_padded + banned_token] = + static_cast(-INFINITY); + } + } +} + +template +void invokeBanBadWords(T* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, + cudaStream_t stream) +{ + dim3 block, grid; + block.x = min(((bad_words_len + 32 - 1) / 32) * 32, 256UL); + grid.x = (bad_words_len + block.x - 1) / block.x; + grid.y = local_batch_size * beam_width; + + ban_bad_words<<>>(logits, + output_ids_buf, + parent_ids_buf, + batch_size, + beam_width, + bad_words, + bad_words_len, + share_words, + id_offset, + vocab_size_padded, + step); + sync_check_cuda_error(); +} + +template void invokeBanBadWords(half* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, + cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeBanBadWords(__nv_bfloat16* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, + cudaStream_t 
stream); +#endif +template void invokeBanBadWords(float* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, + cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/ban_bad_words.h b/src/fastertransformer/kernels/ban_bad_words.h new file mode 100644 index 0000000000000000000000000000000000000000..6c6b31ac407566e79408faa245745b880d292e72 --- /dev/null +++ b/src/fastertransformer/kernels/ban_bad_words.h @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace fastertransformer { + +template +void invokeBanBadWords(T* logits, + const int* output_ids_buf, + const int* parent_ids_buf, + int batch_size, + int local_batch_size, + int beam_width, + const int* bad_words, + bool share_words, + size_t bad_words_len, + int id_offset, + int vocab_size_padded, + size_t step, + cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/beam_search_penalty_kernels.cu b/src/fastertransformer/kernels/beam_search_penalty_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..64c746ef08f7f68667e66c99e4f89c0d1cf56598 --- /dev/null +++ b/src/fastertransformer/kernels/beam_search_penalty_kernels.cu @@ -0,0 +1,313 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "src/fastertransformer/kernels/beam_search_penalty_kernels.h" +#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" + +namespace fastertransformer { + +template +__global__ void add_bias_temperature(T* logits, + const T* bias, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const float temperature) +{ + int tid = threadIdx.x; + int bid = blockIdx.x; + int bbid = blockIdx.y; + + logits += bbid * vocab_size_padded; + + const T MASK_VAL = (std::is_same::value) ? -HALF_FLT_MAX : -FLT_MAX; + const T inv_temp = static_cast(1.0f / (temperature + 1e-6f)); + for (int i = tid + bid * blockDim.x; i < vocab_size_padded; i += blockDim.x * gridDim.x) { + if (i < vocab_size) { + T bias_val = bias == nullptr ? 
(T)(0.0f) : bias[i]; + logits[i] = (logits[i] + bias_val) * inv_temp; + } + else { + logits[i] = MASK_VAL; + } + } +} + +template<> +__global__ void add_bias_temperature(half2* logits, + const half2* bias, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const float temperature) +{ + assert(vocab_size % 2 == 0); + assert(vocab_size_padded % 2 == 0); + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int bbid = blockIdx.y; + + const half2 mask_val = __float2half2_rn(-HALF_FLT_MAX); + const half2 inv_temp = __float2half2_rn(1.0f / (temperature + 1e-6f)); + + const int half_vocab_size = vocab_size / 2; + const int half_vocab_size_padded = vocab_size_padded / 2; + + logits += bbid * half_vocab_size_padded; + for (int index = tid + bid * blockDim.x; index < half_vocab_size_padded; index += blockDim.x * gridDim.x) { + int vocab_idx = index % half_vocab_size_padded; + half2 logit = vocab_idx < half_vocab_size ? __ldg(&logits[index]) : mask_val; + if (vocab_idx < half_vocab_size) { + if (bias != nullptr) { + logit = __hadd2(logit, bias[vocab_idx]); + } + logit = __hmul2(logit, inv_temp); + } + logits[index] = logit; + } +} + +template +__global__ void apply_repetition_penalty(T* logits, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int step, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, + const int max_input_length, + const float repetition_penalty) +{ + assert(step > 0); + + const int tid = threadIdx.x; + const int bbid = blockIdx.x; + const int batch_id = bbid / beam_width; + const int bbsize = batch_size * beam_width; + + logits += bbid * vocab_size_padded; + extern __shared__ char sbuf[]; + T* penalty_logits = reinterpret_cast(sbuf); + // prevent misaligment when sizeof(T) = 2 + int* penalty_indices = reinterpret_cast(sbuf + (sizeof(T) * step + 31) / 32 * 32); + const int input_length = (input_lengths != nullptr) ? input_lengths[bbid] : max_input_length; + if (tid == 0) { + T repet_penalty = static_cast(repetition_penalty); + int prev_id = current_ids[bbid]; + T prev_logit = logits[prev_id]; + penalty_indices[step - 1] = prev_id; + + if (IS_ADDITIVE) { + penalty_logits[step - 1] = prev_logit - repet_penalty; + } + else { + penalty_logits[step - 1] = prev_logit > T(0) ? prev_logit / repet_penalty : prev_logit * repet_penalty; + } + if (step > 1) { + int parent_beam = bbid % beam_width; + for (int i = step - 2; i >= 0; --i) { + // Skip the padded tokens. + if (i >= input_length && i < max_input_length) { + continue; + } + parent_beam = parent_ids[i * bbsize + batch_id * beam_width + parent_beam]; + prev_id = previous_ids[i * bbsize + batch_id * beam_width + parent_beam]; + prev_logit = logits[prev_id]; + penalty_indices[i] = prev_id; + if (IS_ADDITIVE) { + penalty_logits[i] = prev_logit - repet_penalty; + } + else { + penalty_logits[i] = prev_logit > T(0) ? 
prev_logit / repet_penalty : prev_logit * repet_penalty; + } + } + } + } + __syncthreads(); + for (int i = tid; i < step; i += blockDim.x) { + if (i >= input_length && i < max_input_length) { + continue; + } + logits[penalty_indices[i]] = penalty_logits[i]; + } +} + +template +__global__ void apply_min_length_penalty(T* logits, + const int min_length, + const int* end_ids, + const int* sequence_lengths, + const int max_input_length, + const int beam_width, + const int vocab_size_padded) +{ + int bbid = threadIdx.x + blockIdx.x * blockDim.x; // batch-beam index + int bid = bbid / beam_width; // batch index + // We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1, + // which is equal to the length of k/v caches. + if (sequence_lengths[bbid] + 1 - max_input_length < min_length) { + T mask_val = (std::is_same::value) ? -HALF_FLT_MAX : -FLT_MAX; + logits[bbid * vocab_size_padded + end_ids[bid]] = mask_val; + } +} + +template +void invokeAddBiasApplyPenalties(int step, + T* logits, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, + const int* sequence_lengths, + const T* bias, + const int ite, + const int max_input_length, + const int local_batch_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int* end_ids, + const float temperature, + const float repetition_penalty, + const RepetitionPenaltyType repetition_penalty_type, + const int min_length, + cudaStream_t stream) +{ + if (bias != nullptr || temperature != 1.0f || vocab_size != vocab_size_padded) { + dim3 block(512); + if (std::is_same::value && vocab_size % 2 == 0 && vocab_size_padded % 2 == 0) { + dim3 grid((vocab_size_padded / 2 + block.x - 1) / block.x, beam_width * local_batch_size); + add_bias_temperature<<>>(reinterpret_cast(logits), + reinterpret_cast(bias), + batch_size, + beam_width, + vocab_size, + vocab_size_padded, + temperature); + } + else { + dim3 grid((vocab_size_padded + block.x - 1) / block.x, beam_width * local_batch_size); + add_bias_temperature<<>>( + logits, bias, batch_size, beam_width, vocab_size, vocab_size_padded, temperature); + } + } + + if (repetition_penalty_type != RepetitionPenaltyType::None && step > 0) { + if (repetition_penalty != getDefaultPenaltyValue(repetition_penalty_type)) { + size_t smem_size = (sizeof(T) * step + 31) / 32 * 32 + sizeof(int) * step; + dim3 block(256); + dim3 grid(beam_width * local_batch_size); + if (repetition_penalty_type == RepetitionPenaltyType::Multiplicative) { + apply_repetition_penalty + <<>>(logits, + batch_size, + beam_width, + vocab_size, + vocab_size_padded, + step, + current_ids, + previous_ids, + // TODO(jaedeokk): + // Remove (+ite ...) by getting parent_ids with offset + // and then remove 'ite' argument from the function. 
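+                                                           // Note: the '+ ite * beam_width * local_batch_size' below advances parent_ids to the
+                                                           // first beam of the local-batch chunk handled by this iteration, as the TODO above
+                                                           // describes (the offset could instead be folded into the pointer by the caller).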
+ parent_ids + ite * beam_width * local_batch_size, + input_lengths, + max_input_length, + repetition_penalty); + } + else if (repetition_penalty_type == RepetitionPenaltyType::Additive) { + apply_repetition_penalty + <<>>(logits, + batch_size, + beam_width, + vocab_size, + vocab_size_padded, + step, + current_ids, + previous_ids, + parent_ids + ite * beam_width * local_batch_size, + input_lengths, + max_input_length, + repetition_penalty); + } + } + } + + if (step - max_input_length < min_length) { + FT_CHECK_WITH_INFO(sequence_lengths != nullptr, "Need sequence_lengths to apply min length penlaty"); + FT_CHECK_WITH_INFO(end_ids != nullptr, "Need end_id to apply min length penlaty"); + + const int block_size = min(local_batch_size * beam_width, 1024); + const int grid_size = (local_batch_size * beam_width + block_size - 1) / block_size; + apply_min_length_penalty<<>>( + logits, min_length, end_ids, sequence_lengths, max_input_length, beam_width, vocab_size_padded); + } +} + +template void invokeAddBiasApplyPenalties(int step, + float* logits, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, + const int* sequence_lengths, + const float* bias, + const int ite, + const int max_input_length, + const int local_batch_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int* end_ids, + const float temperature, + const float repetition_penalty, + const RepetitionPenaltyType repetition_penalty_type, + const int min_length, + cudaStream_t stream); + +template void invokeAddBiasApplyPenalties(int step, + half* logits, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, + const int* sequence_lengths, + const half* bias, + const int ite, + const int max_input_length, + const int local_batch_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int* end_ids, + const float temperature, + const float repetition_penalty, + const RepetitionPenaltyType repetition_penalty_type, + const int min_length, + cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/beam_search_penalty_kernels.h b/src/fastertransformer/kernels/beam_search_penalty_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..ae67f9654c752bb870ec35248de7739bd3db3792 --- /dev/null +++ b/src/fastertransformer/kernels/beam_search_penalty_kernels.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include + +#include "src/fastertransformer/kernels/penalty_types.h" +#include "src/fastertransformer/utils/cuda_utils.h" + +namespace fastertransformer { + +template +void invokeAddBiasApplyPenalties(int step, + T* logits, + const int* current_ids, + const int* previous_ids, + const int* parent_ids, + const int* input_lengths, + const int* sequence_lengths, + const T* bias, + const int ite, + const int max_input_length, + const int local_batch_size, + const int batch_size, + const int beam_width, + const int vocab_size, + const int vocab_size_padded, + const int* end_ids, + const float temperature, + const float repetition_penalty, + const RepetitionPenaltyType repetition_penalty_type, + const int min_length, + cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/beam_search_topk_kernels.cu b/src/fastertransformer/kernels/beam_search_topk_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..fcaf644b0f320fca3eae9cf8191b53b1a6fa90bc --- /dev/null +++ b/src/fastertransformer/kernels/beam_search_topk_kernels.cu @@ -0,0 +1,845 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! +#elif (CUDART_VERSION >= 11050) +#include +#else +#include "3rdparty/cub/cub.cuh" +#endif + +#include "src/fastertransformer/kernels/beam_search_topk_kernels.h" +#include "src/fastertransformer/kernels/reduce_kernel_utils.cuh" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" +#include "src/fastertransformer/utils/cuda_utils.h" +#include "src/fastertransformer/utils/logger.h" + +namespace fastertransformer { + +template +__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty) +{ + // score = log(prob) / (length)^length_penalty. + if (length_penalty == 0.0f || length == 1) { + return log_prob; + } + return log_prob / static_cast(powf((float)length, length_penalty)); +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ void beam_topK_kernel(const T* log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const bool* finished, + const int* sequence_lengths, + const int vocab_size, + T diversity_rate, + float length_penalty) +{ + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int thread_id = threadIdx.x; + int block_id = blockIdx.x; // batch beam index. + TopK partial; + + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + +#pragma unroll + for (int i = 0; i < MAX_K; ++i) { + partial.p[i] = -1; + partial.u[i] = -MAX_T_VAL; + } + +#pragma unroll + for (int elem_id = thread_id; elem_id < vocab_size; elem_id += THREADBLOCK_SIZE) { + int index = elem_id + block_id * vocab_size; + T score = length_penalty == 0.0f ? log_probs[index] : + apply_length_penalty(log_probs[index], + finished[block_id] ? 
sequence_lengths[block_id] : + sequence_lengths[block_id] + 1, + length_penalty); + partial.insert(score, index); + } + + TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (thread_id == 0) { + int index = block_id * MAX_K; + +#pragma unroll + for (int i = 0; i < MAX_K; ++i) { + topk_tmp_id_buf[index + i] = total.p[i]; + topk_tmp_val_buf[index + i] = total.u[i] + diversity_rate * (T)i; + } + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void batch_topK_kernel(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) +{ + int thread_id = threadIdx.x; + int block_id = blockIdx.x; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + TopK partial; + if (thread_id == 0) { + for (int i = 0; i < MAX_K; ++i) { + partial.p[i] = -1; + partial.u[i] = -MAX_T_VAL; + } + + int index = block_id * MAX_K * MAX_K; + for (int i = 0; i < MAX_K * MAX_K; i++) { + partial.insert((T)topk_tmp_val_buf[index + i], topk_tmp_id_buf[index + i]); + } + + index = block_id * MAX_K; + for (int i = 0; i < MAX_K; i++) { + id_buf[index + i] = partial.p[i]; + } + } +} + +template +__launch_bounds__(THREADBLOCK_SIZE) __global__ + void batch_topK_kernel_v2(int* topk_tmp_id_buf, T* topk_tmp_val_buf, int* id_buf) +{ + typedef cub::BlockReduce, THREADBLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int tid = threadIdx.x; + int bid = blockIdx.x; + TopK partial; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + +#pragma unroll + for (int i = 0; i < MAX_K; ++i) { + partial.p[i] = -1; + partial.u[i] = -MAX_T_VAL; + } + + int ite = MAX_K * MAX_K / THREADBLOCK_SIZE; +#pragma unroll + for (int i = 0; i < ite; i++) { + int index = bid * MAX_K * MAX_K + i * THREADBLOCK_SIZE + tid; + partial.insert((T)topk_tmp_val_buf[index], topk_tmp_id_buf[index]); + } + + TopK total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op); + + if (tid == 0) { +#pragma unroll + for (int i = 0; i < MAX_K; i++) { + id_buf[bid * MAX_K + i] = total.p[i]; + } + } +} + +template +__global__ void topk_stage_1_opt3(const T* __restrict log_probs, + T* tmp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const bool* finished, + const int* sequence_lengths, + const int k, + const int vocab_size, + const float length_penalty, + const int* end_ids) +{ + typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index) + const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam + const int tmp_log_buf_index = row_id * vocab_size; + const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; + TopK_2 partial; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; + + if (finished != nullptr && finished[row_id] == true) { + if (tid < k) { + const int index = tmp_topk_buf_index + tid; + if (block_lane == 0 && tid == 0) { + const int end_id = end_ids[row_id / k]; + topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; + topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id]; + } + else { + topk_tmp_id_buf[index] = -1; + topk_tmp_val_buf[index] = -MAX_T_VAL; + } + } + return; + } + + for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; + elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { + int index = elem_id + tmp_log_buf_index; + tmp_log_probs[index] = log_probs[index]; + } + + for (int ite = 0; ite < k; ite++) { + partial.init(); +#pragma unroll + for (int elem_id = tid + block_lane * BLOCK_SIZE_; elem_id < vocab_size; + elem_id += BLOCK_SIZE_ * BLOCKS_PER_BEAM_) { + int index = elem_id + tmp_log_buf_index; + partial.insert(tmp_log_probs[index], index); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if (tid == 0) { + const int index = tmp_topk_buf_index + ite; + topk_tmp_id_buf[index] = total.p; + topk_tmp_val_buf[index] = total.u; + tmp_log_probs[total.p] = -MAX_T_VAL; + } + __syncthreads(); + } +} + +template +__global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + BeamHypotheses beam_hyps, + const int* end_ids, + const int vocab_size, + const int k) +{ + const int size = k * k * BLOCKS_PER_BEAM_; + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + + typedef cub::BlockReduce, BLOCK_SIZE_> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char array[]; + T* s_val = topk_tmp_val_buf + batch_id * size; + int* s_id = (int*)(array); + + __shared__ int selected_beams; + __shared__ bool is_stop; + + if (tid == 0) { + selected_beams = 0; + is_stop = false; + } + __syncthreads(); + if (beam_hyps.num_beams != nullptr) { + const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; + if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) { + // initialize the buffer + beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; + } + else if (beam_hyps.num_beams[global_batch_idx] == k) { + return; + } + } + + TopK_2 partial; + + // In some cases, we may encounter k finished sentences, but scores are bad. So, the max iteration + // is 2*k here + for (int ite = 0; ite < 2 * k; ite++) { + partial.init(); +#pragma unroll + for (int i = tid; i < size; i += BLOCK_SIZE_) { + partial.insert(s_val[i], i); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if (tid == 0) { + if (beam_hyps.num_beams != nullptr + && topk_tmp_id_buf[batch_id * size + total.p] % vocab_size == end_ids[batch_id]) { + // if beam_token does not belong to top num_beams tokens, it should not be added. 
Refer from + // https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257 + if (ite >= k) { + s_val[total.p] = -MAX_T_VAL; + } + else { + const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; + const float normed_score = + apply_length_penalty(s_val[total.p], beam_hyps.step, beam_hyps.length_penalty); + const int num_beam = beam_hyps.num_beams[global_batch_idx]; + int beam_idx = num_beam; + // If there are beam_width finished sentences, check that the score of selected candidatet + // is higher than min_normed_score or not. If current score is better, replace worst one + // and update the min_normed_score. + if (num_beam == k) { + if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) { + // end the tracing and exist this for loop + selected_beams = k; + is_stop = true; + break; + } + else { + // find the beam index which's score = min_normed_score, erase it. + for (int j = 0; j < k; j++) { + if (beam_hyps.normed_scores[global_batch_idx * k + j] + == beam_hyps.min_normed_scores[global_batch_idx]) { + beam_idx = j; + beam_hyps.num_beams[global_batch_idx]--; + + beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; + beam_hyps.normed_scores[global_batch_idx * k + j] = normed_score; + for (int l = 0; l < k; l++) { + beam_hyps.min_normed_scores[global_batch_idx] = + min(beam_hyps.min_normed_scores[global_batch_idx], + beam_hyps.normed_scores[global_batch_idx * k + l]); + } + break; + } + } + } + } + const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) + * (beam_hyps.max_seq_len); + beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; + + int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; + for (int j = beam_hyps.step - 1; j >= 0; j--) { + const int src_idx = j * beam_hyps.batch_size * k + + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; + + beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; + prev_id = beam_hyps.parent_ids_src[src_idx]; + } + const int tgt_beam_idx = global_batch_idx * k + beam_idx; + beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; + beam_hyps.normed_scores[tgt_beam_idx] = normed_score; + beam_hyps.min_normed_scores[global_batch_idx] = + min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); + + s_val[total.p] = -MAX_T_VAL; + + beam_hyps.num_beams[global_batch_idx]++; + } + } + else { + s_id[selected_beams] = total.p; + s_val[total.p] = -MAX_T_VAL; + selected_beams++; + } + } + __syncthreads(); + if (selected_beams >= k) { + break; + } + } + if (tid < k && is_stop == false) { + ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]]; + } +} + +template +__global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, + T* tmp_log_probs, + int* topk_tmp_id_buf, + T* topk_tmp_val_buf, + const bool* finished, + const int* sequence_lengths, + const int k, + const int vocab_size, + const float length_penalty) +{ + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? 
HALF_FLT_MAX : FLT_MAX; + typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs + const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam + const int tmp_log_buf_index = row_id * vocab_size; + const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k; + TopK_2 partial; + + for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) { + int index = elem_id + tmp_log_buf_index; + tmp_log_probs[index] = log_probs[index]; + } + + for (int ite = 0; ite < k; ite++) { + partial.init(); +#pragma unroll + for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; + elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) { + int index = elem_id + tmp_log_buf_index; + partial.insert(tmp_log_probs[index], index); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if (tid == 0) { + const int index = tmp_topk_buf_index + ite; + topk_tmp_id_buf[index] = total.p; + topk_tmp_val_buf[index] = total.u; + tmp_log_probs[total.p] = -MAX_T_VAL; + } + __syncthreads(); + } +} + +template +__global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, + T* topk_tmp_val_buf, + int* ids, + BeamHypotheses beam_hyps, + const int* end_ids, + const int k, + const int vocab_size) +{ + const int size = k * k * BLOCKS_PER_BEAM; + const int tid = threadIdx.x; + const int batch_id = blockIdx.x; + const bool IS_FP16 = std::is_same::value; + const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; + + typedef cub::BlockReduce, BLOCK_SIZE> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + extern __shared__ char array[]; + T* s_val = topk_tmp_val_buf + batch_id * size; + int* s_id = (int*)(array); + + __shared__ int selected_beams; + __shared__ bool is_stop; + + if (tid == 0) { + selected_beams = 0; + is_stop = false; + } + __syncthreads(); + if (beam_hyps.num_beams != nullptr) { + const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; + if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) { + beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; + } + else if (beam_hyps.num_beams[global_batch_idx] == k) { + return; + } + } + + TopK_2 partial; + + // In some cases, we may encounter k finished sentences, but scores are bad. So, the max iteration + // is 2*k here + for (int ite = 0; ite < 2 * k; ite++) { + partial.init(); +#pragma unroll + for (int i = tid; i < size; i += BLOCK_SIZE) { + partial.insert(s_val[i], i); + } + + TopK_2 total = BlockReduce(temp_storage).Reduce(partial, reduce_topk_op_2); + + if (tid == 0) { + if (beam_hyps.num_beams != nullptr + && topk_tmp_id_buf[batch_id * size + total.p] % vocab_size == end_ids[batch_id]) { + // if beam_token does not belong to top num_beams tokens, it should not be added. 
Refer from + // https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/generation_beam_search.py#L257 + if (ite >= k) { + s_val[total.p] = -MAX_T_VAL; + } + else { + const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; + const float normed_score = + apply_length_penalty(s_val[total.p], beam_hyps.step, beam_hyps.length_penalty); + const int num_beam = beam_hyps.num_beams[global_batch_idx]; + int beam_idx = num_beam; + // If there are beam_width finished sentences, check that the score of selected candidatet + // is higher than min_normed_score or not. If current score is better, replace worst one + // and update the min_normed_score. + if (num_beam == k) { + if (normed_score < beam_hyps.min_normed_scores[global_batch_idx]) { + // end the tracing and exist this for loop + selected_beams = k; + is_stop = true; + break; + } + else { + // find the beam index which's score = min_normed_score, erase it. + for (int j = 0; j < k; j++) { + if (beam_hyps.normed_scores[global_batch_idx * k + j] + == beam_hyps.min_normed_scores[global_batch_idx]) { + beam_idx = j; + beam_hyps.num_beams[global_batch_idx]--; + + beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; + beam_hyps.normed_scores[global_batch_idx * k + j] = normed_score; + for (int l = 0; l < k; l++) { + beam_hyps.min_normed_scores[global_batch_idx] = + min(beam_hyps.min_normed_scores[global_batch_idx], + beam_hyps.normed_scores[global_batch_idx * k + l]); + } + break; + } + } + } + } + const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) + * (beam_hyps.max_seq_len); + beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; + + int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; + for (int j = beam_hyps.step - 1; j >= 0; j--) { + const int src_idx = j * beam_hyps.batch_size * k + + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; + + beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; + prev_id = beam_hyps.parent_ids_src[src_idx]; + } + const int tgt_beam_idx = global_batch_idx * k + beam_idx; + beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; + beam_hyps.normed_scores[tgt_beam_idx] = normed_score; + beam_hyps.min_normed_scores[global_batch_idx] = + min(beam_hyps.min_normed_scores[global_batch_idx], beam_hyps.normed_scores[tgt_beam_idx]); + + s_val[total.p] = -MAX_T_VAL; + + beam_hyps.num_beams[global_batch_idx]++; + } + } + else { + s_id[selected_beams] = total.p; + s_val[total.p] = -MAX_T_VAL; + selected_beams++; + } + } + __syncthreads(); + if (selected_beams >= k) { + break; + } + } + if (tid < k && is_stop == false) { + ids[batch_id * k + tid] = topk_tmp_id_buf[batch_id * size + s_id[tid]]; + } +} + +#define CASE_K_DIV(K, BLOCK_SIZE_1, BLOCK_SIZE_2) \ + case K: \ + beam_topK_kernel<<>>(log_probs, \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + finished, \ + sequence_lengths, \ + vocab_size, \ + diversity_rate, \ + length_penalty); \ + if (K < 10) \ + batch_topK_kernel \ + <<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \ + else \ + batch_topK_kernel_v2<<>>(topk_tmp_id_buf, topk_tmp_val_buf, ids); \ + break; + +#define CASE_K(K, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \ + case K: \ + topk_stage_1_opt3 \ + <<>>(log_probs, \ + temp_log_probs, \ + topk_tmp_id_buf, \ + topk_tmp_val_buf, \ + finished, \ + sequence_lengths, \ + beam_width, \ + vocab_size, \ + length_penalty, \ + end_ids); \ + topk_stage_2_opt3 \ + <<>>( \ + 
topk_tmp_id_buf, topk_tmp_val_buf, ids, *beam_hyps, end_ids, vocab_size, beam_width); \ + sync_check_cuda_error(); \ + break; + +template +void invokeTopkBeamSearch(void* workspace, + size_t& workspace_size, + T* log_probs, + int* ids, + BeamHypotheses* beam_hyps, + const bool* finished, + const int* sequence_lengths, + const int batch_size, + const int beam_width, + const int vocab_size_padded_, + const T diversity_rate, + const float length_penalty, + const int* end_ids, + cudaStream_t stream) +{ + FT_LOG_DEBUG("%s", __PRETTY_FUNCTION__); + // log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a token. + const int vocab_size = vocab_size_padded_; + // Beam size should be less than or equal to vocab size. + assert(beam_width <= vocab_size); + // Beam search needs the sequence lengths of beams to apply length penalty. + assert(length_penalty == 0.0f || sequence_lengths != nullptr); + const int max_block_per_beam = 8; + int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float + int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int + int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float + + // prevent memory misaligned address + temp_log_probs_buf_size = (int)(ceil(temp_log_probs_buf_size / 4.)) * 4; + topk_tmp_ids_buf_size = (int)(ceil(topk_tmp_ids_buf_size / 4.)) * 4; + topk_tmp_val_buf_size = (int)(ceil(topk_tmp_val_buf_size / 4.)) * 4; + + if (workspace == nullptr) { + workspace_size = sizeof(float) * temp_log_probs_buf_size + sizeof(int) * topk_tmp_ids_buf_size + + sizeof(float) * topk_tmp_val_buf_size; + return; + } + else { + T* temp_log_probs = (T*)workspace; + int* topk_tmp_id_buf = (int*)(temp_log_probs + temp_log_probs_buf_size); + T* topk_tmp_val_buf = (T*)(topk_tmp_id_buf + topk_tmp_ids_buf_size); + if (diversity_rate == 0.0f) { + switch (beam_width) { + CASE_K(1, 128, 128, 8); + CASE_K(4, 128, 128, 8); + CASE_K(10, 128, 128, 8); + CASE_K(16, 128, 128, 5); + CASE_K(32, 256, 128, 1); + CASE_K(64, 256, 256, 1); + default: + topk_stage_1_opt2_general + <<>>(log_probs, + temp_log_probs, + topk_tmp_id_buf, + topk_tmp_val_buf, + finished, + sequence_lengths, + beam_width, + vocab_size, + length_penalty); + topk_stage_2_opt2_general + <<>>( + topk_tmp_id_buf, topk_tmp_val_buf, ids, *beam_hyps, end_ids, beam_width, vocab_size); + break; + } + } + else { + switch (beam_width) { + CASE_K_DIV(1, 256, 256); + CASE_K_DIV(4, 256, 256); + CASE_K_DIV(16, 256, 64); + CASE_K_DIV(32, 256, 64); + CASE_K_DIV(64, 256, 64); + default: + FT_CHECK_WITH_INFO(false, fmtstr("Topk kernel does not support beamwidth = %d \n", beam_width)); + break; + } + } + return; + } +} + +#undef CASE_K +#undef CASE_K_DIV + +template void invokeTopkBeamSearch(void* workspace, + size_t& workspace_size, + float* log_probs, + int* ids, + BeamHypotheses* beam_hyps, + const bool* finished, + const int* sequence_lengths, + const int batch_size, + const int beam_width, + const int vocab_size_padded_, + const float diversity_rate, + const float length_penalty, + const int* end_ids, + cudaStream_t stream); + +template +__global__ void tileEncoderResults(T* tiled_output, + int* tiled_sequence_length, + const T* output, + const int* sequence_length, + const uint batch_size, + const uint beam_width, + const uint d_model) +{ + if (blockIdx.x == 0) { + for (uint i = threadIdx.x; i < batch_size * beam_width; i += blockDim.x) { + tiled_sequence_length[i] = sequence_length[i / beam_width]; + } + } + + 
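+    // The launcher below sets grid = (batch_size, beam_width, mem_max_seq_len), so blockIdx.x/y/z
+    // select the batch, beam and token position. tgt_offset addresses the d_model-wide row
+    // tiled_output[batch][beam][token], while src_offset addresses output[batch][token]; each
+    // encoder row is therefore copied beam_width times.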
int tgt_offset = + blockIdx.x * gridDim.y * gridDim.z * d_model + blockIdx.y * gridDim.z * d_model + blockIdx.z * d_model; + int src_offset = blockIdx.x * gridDim.z * d_model + blockIdx.z * d_model; + for (uint i = threadIdx.x; i < d_model; i += blockDim.x) { + tiled_output[i + tgt_offset] = output[i + src_offset]; + } +} + +template +void invokeTileEncoderResults(T* tiled_output, + int* tiled_sequence_length, + const T* output, + const int* sequence_length, + const size_t batch_size, + const size_t beam_width, + const size_t mem_max_seq_len, + const size_t d_model, + cudaStream_t stream) +{ + // tiled_output: [batch_size, beam_width, mem_max_seq_len, d_model] + // tiled_sequence_length: [batch_size, beam_width] + + // output: [batch_size, mem_max_seq_len, d_model] + // sequence_length [batch_size] + + dim3 grid(batch_size, beam_width, mem_max_seq_len); + bool is_half2 = (std::is_same::value) && (d_model % 2 == 0); + + if (is_half2) { + using T2 = typename TypeConverter::Type; // fp16 to half2, bf16 to bf162 + dim3 block(min(512, (int)(d_model / 2))); + tileEncoderResults<<>>((T2*)tiled_output, + tiled_sequence_length, + (const T2*)output, + sequence_length, + batch_size, + beam_width, + d_model / 2); + } + else { + dim3 block(min(512, (int)d_model)); + tileEncoderResults<<>>( + tiled_output, tiled_sequence_length, output, sequence_length, batch_size, beam_width, d_model); + } +} + +template void invokeTileEncoderResults(float* tiled_output, + int* tiled_sequence_length, + const float* output, + const int* sequence_length, + const size_t batch_size, + const size_t beam_width, + const size_t mem_max_seq_len, + const size_t d_model, + cudaStream_t stream); + +template void invokeTileEncoderResults(half* tiled_output, + int* tiled_sequence_length, + const half* output, + const int* sequence_length, + const size_t batch_size, + const size_t beam_width, + const size_t mem_max_seq_len, + const size_t d_model, + cudaStream_t stream); + +template void invokeTileEncoderResults(half2* tiled_output, + int* tiled_sequence_length, + const half2* output, + const int* sequence_length, + const size_t batch_size, + const size_t beam_width, + const size_t mem_max_seq_len, + const size_t d_model, + cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, + int* tiled_sequence_length, + const __nv_bfloat16* output, + const int* sequence_length, + const size_t batch_size, + const size_t beam_width, + const size_t mem_max_seq_len, + const size_t d_model, + cudaStream_t stream); +#endif + +__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, + const bool* finished, + const float* cum_log_probs, + const int batch_size, + const int beam_width) +{ + const int bid = blockIdx.x; + const int tgt_start_idx = beam_hyps.num_beams[bid]; + if (beam_hyps.is_done[bid]) { + return; + } + for (int i = 0; i < beam_width; i++) { + if (threadIdx.x == 0) { + const int src_beam_idx = bid * beam_width + i; + const int tgt_beam_idx = bid * beam_width * 2 + i + tgt_start_idx; + + const int length = beam_hyps.sequence_lengths_src[src_beam_idx]; + + beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] = + beam_hyps.output_ids_src[length * batch_size * beam_width + src_beam_idx]; + if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) { + beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + length] = + beam_hyps.log_probs_src[length * batch_size * beam_width + src_beam_idx]; + } + int prev_id = 
beam_hyps.parent_ids_src[length * batch_size * beam_width + src_beam_idx]; + for (int j = length - 1; j >= 0; j--) { + // output_ids_tgt need to use max_seq_len + 1 because its shape is + // [bs, beam_width, max_seq_len + 1] + beam_hyps.output_ids_tgt[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] = + beam_hyps.output_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id]; + if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr) { + beam_hyps.log_probs[(tgt_beam_idx) * (beam_hyps.max_seq_len + 1) + j] = + beam_hyps.log_probs_src[j * batch_size * beam_width + bid * beam_width + prev_id]; + } + prev_id = beam_hyps.parent_ids_src[j * batch_size * beam_width + bid * beam_width + prev_id]; + } + beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = length; + + beam_hyps.normed_scores[tgt_beam_idx] = apply_length_penalty( + cum_log_probs[src_beam_idx], finished[src_beam_idx] ? length + 1 : length, beam_hyps.length_penalty); + beam_hyps.cum_log_probs[tgt_beam_idx] = cum_log_probs[src_beam_idx]; + + beam_hyps.num_beams[bid]++; + } + } +} + +void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, + const bool* finished, + const float* cum_log_probs, + const int batch_size, + const int beam_width, + cudaStream_t stream) +{ + insertUnfinishedPath<<>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width); +} + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/beam_search_topk_kernels.h b/src/fastertransformer/kernels/beam_search_topk_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..60732a5943a6c1e305a6f76a46e0177007474d87 --- /dev/null +++ b/src/fastertransformer/kernels/beam_search_topk_kernels.h @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#pragma once + +namespace fastertransformer { + +// In original beam search implementation, if a beam is finished, we set it as finished +// and only continue to do beam search on remain beams (namely, beam_width - 1 beams in next step) +// +// In this implementation, when a beam is finished, we trace the path and record it in output_ids_tgt, +// and also record the normalized scores. And the beam search continue to use `beam_width` beams in +// next step. +// +// After we collect `beam_width` beams, we will sort them by their norm_scores. 
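+// For reference, the ranking score kept in `normed_scores` below is the cumulative log prob divided by
+// length^length_penalty, mirroring the device helper apply_length_penalty() in beam_search_topk_kernels.cu.
+// A minimal host-side sketch of that formula follows; the helper name is ours (it is not used by any
+// kernel) and it assumes math.h is visible through the includes above.
+static inline float normed_score_sketch(float cum_log_prob, int length, float length_penalty)
+{
+    // No penalty requested, or a single generated token: keep the cumulative log prob as-is.
+    if (length_penalty == 0.0f || length == 1) {
+        return cum_log_prob;
+    }
+    // Otherwise normalize: cum_log / (length ** length_penalty).
+    return cum_log_prob / powf(static_cast<float>(length), length_penalty);
+}
+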
+struct BeamHypotheses { + int* output_ids_tgt = nullptr; + int* sequence_lengths_tgt = nullptr; + float* cum_log_probs = nullptr; // cum_log + float* normed_scores = nullptr; // cum_log / (length**length_penalty) + float* log_probs = nullptr; // log probs of each generated token + float* min_normed_scores = nullptr; // record the min normed scores for each batch + int* num_beams = nullptr; // the number of finished beams we collect + bool* is_done = nullptr; + + // Used to set inputs + const int* output_ids_src; + const int* parent_ids_src; + const int* sequence_lengths_src; + const int* end_ids; + const float* log_probs_src; + + // some variables for kernels + int step; + int ite; + int batch_size; + int local_batch_size; + int max_seq_len; + float length_penalty; + + bool early_stopping = true; + bool is_return_normed_score = true; // return normed_cum_log_probs or cum_log_probs +}; + +template +void invokeTopkBeamSearch(void* workspace, + size_t& workspace_size, + T* log_probs, + int* ids, + BeamHypotheses* beam_hyps, + const bool* finished, + const int* sequence_lengths, + const int batch_size, + const int beam_width, + const int vocab_size_padded_, + const T diversity_rate, + const float length_penalty, + const int* end_ids, + cudaStream_t stream); + +template +void invokeTileEncoderResults(T* tiled_encoder_output, + int* tiled_encoder_sequence_length, + const T* encoder_output, + const int* encoder_sequence_length, + const size_t batch_size, + const size_t beam_width, + const size_t mem_max_seq_len, + const size_t d_model, + cudaStream_t stream); + +void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, + const bool* finished, + const float* cum_log_probs, + const int batch_size, + const int beam_width, + cudaStream_t stream); + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.cu b/src/fastertransformer/kernels/bert_preprocess_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..a57161c8596659298631d2c054ae60e731912d7d --- /dev/null +++ b/src/fastertransformer/kernels/bert_preprocess_kernels.cu @@ -0,0 +1,470 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "bert_preprocess_kernels.h" +#include "src/fastertransformer/utils/cuda_bf16_fallbacks.cuh" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" + +namespace fastertransformer { + +__global__ void getPaddingOffsetAndCuSeqLensKernel(size_t* h_valid_word_num, + int* tmp_mask_offset, + int* cu_seqlens, + const int* sequence_length, + const int batch_size, + const int max_seq_len) +{ + // do cumulated sum + int total_seq_len = 0; + int cum_offset = 0; + int index = 0; + const bool calculate_cu_seqlens = cu_seqlens != nullptr; + for (int i = 0; i < batch_size; i++) { + const int seq_len = sequence_length[i]; + if (calculate_cu_seqlens) { + cu_seqlens[i] = total_seq_len; + } + for (int j = 0; j < seq_len; j++) { + tmp_mask_offset[index] = cum_offset; + index++; + } + cum_offset += max_seq_len - seq_len; + total_seq_len += seq_len; + } + if (calculate_cu_seqlens) { + cu_seqlens[batch_size] = total_seq_len; + } + h_valid_word_num[0] = (size_t)total_seq_len; +} + +void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num, + size_t* h_token_num, + int* tmp_mask_offset, + int* cu_seqlens, + const int* sequence_lengths, + const int batch_size, + const int max_seq_len, + cudaStream_t stream) +{ + h_pinned_token_num[0] = 0; + getPaddingOffsetAndCuSeqLensKernel<<<1, 1, 0, stream>>>( + h_pinned_token_num, tmp_mask_offset, cu_seqlens, sequence_lengths, batch_size, max_seq_len); + while (((volatile size_t*)h_pinned_token_num)[0] == 0) {}; + h_token_num[0] = h_pinned_token_num[0]; + sync_check_cuda_error(); +} + +template +__global__ void buildEncoderAttentionMaskKernel(T* attention_mask, const int* sequence_lengths, const int max_seq_len) +{ + // sequence_lengths: [batch_size] + // attention_mask: [batch_size, 1, max_seq_len, max_seq_len] + attention_mask += blockIdx.x * max_seq_len * max_seq_len; + const int length = sequence_lengths[blockIdx.x]; + for (int i = threadIdx.x; i < max_seq_len * max_seq_len; i += blockDim.x) { + // int row_id = i / max_seq_len; + int col_id = i % max_seq_len; + // if (row_id < length && col_id < length) { + // TODO (bhsueh) check this modification is ok or not on other rmodel + if (col_id < length) { + attention_mask[i] = (T)(1.0f); + } + else { + attention_mask[i] = (T)(0.0f); + } + } +} + +template +void invokeBuildEncoderAttentionMask( + T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream) +{ + buildEncoderAttentionMaskKernel<<>>(attention_mask, sequence_lengths, max_seq_len); +} + +template void invokeBuildEncoderAttentionMask(float* attention_mask, + const int* sequence_lengths, + const int batch_size, + const int max_seq_len, + cudaStream_t stream); +template void invokeBuildEncoderAttentionMask(half* attention_mask, + const int* sequence_lengths, + const int batch_size, + const int max_seq_len, + cudaStream_t stream); +#ifdef ENABLE_FP8 +template void invokeBuildEncoderAttentionMask(__nv_fp8_e4m3* attention_mask, + const int* sequence_lengths, + const int batch_size, + const int max_seq_len, + cudaStream_t stream); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeBuildEncoderAttentionMask(__nv_bfloat16* attention_mask, + const int* sequence_lengths, + const int batch_size, + const int max_seq_len, + cudaStream_t stream); +#endif + +__global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, const int* sequence_length, const int batch_size) +{ + // use for get tensorrt fused mha padding offset + // 
when we remove the padding + + extern __shared__ int tmp_offset[]; + if (threadIdx.x == 0) { + tmp_offset[0] = 0; + for (int i = 0; i < batch_size; i++) { + tmp_offset[i + 1] = tmp_offset[i] + sequence_length[i]; + } + } + __syncthreads(); + + for (int i = threadIdx.x; i < batch_size + 1; i += blockDim.x) { + trt_mha_padding_offset[i] = tmp_offset[i]; + } +} + +void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, + const int* sequence_length, + const int batch_size, + cudaStream_t stream) +{ + getTrtPaddingOffsetKernel<<<1, 256, sizeof(int) * (batch_size + 1), stream>>>( + trt_mha_padding_offset, sequence_length, batch_size); +} + +__global__ void getTrtPaddingOffsetKernel(int* trt_mha_padding_offset, + const int* sequence_length, + const int request_batch_size, + const int request_seq_len) +{ + // use for get tensorrt fused mha padding offset + // when we keep the padding + + extern __shared__ int tmp_offset[]; + if (threadIdx.x == 0) { + tmp_offset[0] = 0; + for (int i = 0; i < request_batch_size; i++) { + tmp_offset[i * 2 + 1] = tmp_offset[i * 2] + sequence_length[i]; + tmp_offset[i * 2 + 2] = request_seq_len * (i + 1); + } + } + __syncthreads(); + + for (int i = threadIdx.x; i < 2 * request_batch_size + 1; i += blockDim.x) { + trt_mha_padding_offset[i] = tmp_offset[i]; + } +} + +void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, + const int* sequence_length, + const int request_batch_size, + const int request_seq_len, + cudaStream_t stream) +{ + getTrtPaddingOffsetKernel<<<1, 256, sizeof(int) * (2 * request_batch_size + 1), stream>>>( + trt_mha_padding_offset, sequence_length, request_batch_size, request_seq_len); +} + +template +__global__ void rebuild_sequence_length_padding(const T* src, T* dst, const int* padding_offset, const int n) +{ + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int dst_seq_id = bid + padding_offset[bid]; + const int src_seq_id = bid; + + for (int i = tid; i < n; i += blockDim.x) { + dst[dst_seq_id * n + i] = src[src_seq_id * n + i]; + } +} + +template +void invokeRebuildPadding( + T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream) +{ + // src: [token_num, hidden_dim] + // dst: [batch_size*max_seq_len, hidden_dim] + rebuild_sequence_length_padding<<>>(src, dst, padding_offset, hidden_dim); +} + +template +void invokeRebuildPadding( + T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); +template void invokeRebuildPadding(float* dst, + const float* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); +template void invokeRebuildPadding(half* dst, + const half* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeRebuildPadding(__nv_bfloat16* dst, + const __nv_bfloat16* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); +#endif // ENABLE_BF16 + +#ifdef ENABLE_FP8 +template void invokeRebuildPadding(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); +#endif // ENABLE_FP8 + +template +__global__ void remove_padding(T* tgt, const T* src, const int* padding_offset, const int n) +{ + const int tid = threadIdx.x; + const int bid = blockIdx.x; + const int src_seq_id = bid + padding_offset[bid]; + const int tgt_seq_id = 
bid; + + for (int i = tid; i < n; i += blockDim.x) { + tgt[tgt_seq_id * n + i] = src[src_seq_id * n + i]; + } +} + +template +void invokeRemovePadding( + T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream) +{ + remove_padding<<>>(dst, src, padding_offset, hidden_dim); +} + +template void invokeRemovePadding(float* dst, + const float* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); + +template void invokeRemovePadding(half* dst, + const half* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); +#ifdef ENABLE_FP8 +template void invokeRemovePadding(__nv_fp8_e4m3* dst, + const __nv_fp8_e4m3* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeRemovePadding(__nv_bfloat16* dst, + const __nv_bfloat16* src, + const int* padding_offset, + const int token_num, + const int hidden_dim, + cudaStream_t stream); +#endif + +template +__global__ void buildRelativeAttentionBias(T* relative_attention_bias, + const T* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance) +{ + + const int head_id = blockIdx.x; + for (int seq_id = threadIdx.x; seq_id < seq_len * seq_len; seq_id += blockDim.x) { + int row_id = seq_id / seq_len; + int col_id = seq_id % seq_len; + + int relative_position = col_id - row_id; + + int relative_buckets = 0; + int tmp_num_bucket = num_bucket; + if (is_bidirectional) { + tmp_num_bucket /= 2; + if (relative_position > 0) { + relative_buckets += tmp_num_bucket; + } + else { + relative_position *= -1; + } + } + else { + relative_position = abs(relative_position); + } + + int max_exact = tmp_num_bucket / 2; + bool is_small = relative_position < max_exact; + + int relative_position_if_large = + max_exact + + (int)(logf(relative_position * 1.0f / max_exact) / logf((float)max_distance / max_exact) + * (tmp_num_bucket - max_exact)); + + relative_position_if_large = min(relative_position_if_large, tmp_num_bucket - 1); + + relative_buckets += is_small ? 
relative_position : relative_position_if_large; + + relative_attention_bias[head_id * seq_len * seq_len + seq_id] = + relative_attention_bias_table[head_id * num_bucket + relative_buckets]; + } +} + +template +void invokeBuildRelativeAttentionBias(T* relative_attention_bias, + const T* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, + const PositionEmbeddingType position_embedding_type, + cudaStream_t stream) +{ + if (position_embedding_type == PositionEmbeddingType::absolute) { + return; + } + dim3 grid(head_num); + dim3 block(256); + buildRelativeAttentionBias<<>>(relative_attention_bias, + relative_attention_bias_table, + head_num, + seq_len, + num_bucket, + is_bidirectional, + max_distance); +} + +template void invokeBuildRelativeAttentionBias(float* relative_attention_bias, + const float* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, + const PositionEmbeddingType position_embedding_type, + cudaStream_t stream); + +template void invokeBuildRelativeAttentionBias(half* relative_attention_bias, + const half* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, + const PositionEmbeddingType position_embedding_type, + cudaStream_t stream); + +#ifdef ENABLE_BF16 +template void invokeBuildRelativeAttentionBias(__nv_bfloat16* relative_attention_bias, + const __nv_bfloat16* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, + const PositionEmbeddingType position_embedding_type, + cudaStream_t stream); +#endif + +#ifdef ENABLE_FP8 + +template +__global__ void getLastTokenDequantize(getLastTokenDequantizeParam param) +{ + param.output[blockIdx.x * param.d_model + threadIdx.x] = + (T_OUT)((float)param.input[blockIdx.x * param.max_seq_len * param.d_model + threadIdx.x] + * __ldg(param.input_scale)); +} + +template +void invokeGetLastTokenDequantize(getLastTokenDequantizeParam param) +{ + FT_CHECK(param.d_model <= 1024); + getLastTokenDequantize<<>>(param); +} + +template void invokeGetLastTokenDequantize<__nv_bfloat16, __nv_fp8_e4m3>( + getLastTokenDequantizeParam<__nv_bfloat16, __nv_fp8_e4m3> param); + +template +__global__ void quantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam param) +{ + for (int i = threadIdx.x; i < param.d_model; i += blockDim.x) { + int padded_row_id = blockIdx.x + (param.padding_offset == nullptr ? 0 : param.padding_offset[blockIdx.x]); + if (quantize_mode == QUANTIZE_MODE::PER_TENSOR) { + param.dst[padded_row_id * param.d_model + i] = + (T_OUT)((float)param.src[blockIdx.x * param.d_model + i] * __ldg(param.scale)); + } + else if (quantize_mode == QUANTIZE_MODE::PER_CHANNEL) { + param.dst[padded_row_id * param.d_model + i] = + (T_OUT)((float)param.src[blockIdx.x * param.d_model + i] * __ldg(param.scale + i)); + } + } +} + +template<> +__global__ void +quantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam param) +{ + int padded_row_id = blockIdx.x + (param.padding_offset == nullptr ? 
0 : __ldg(¶m.padding_offset[blockIdx.x])); + __nv_fp8x4_e4m3* src_ptr = ((__nv_fp8x4_e4m3*)param.src) + blockIdx.x * (param.d_model / 4); + half2* dst_ptr = ((half2*)param.dst) + padded_row_id * (param.d_model / 2); + half2 scale = cuda_cast(__ldg(param.scale)); + for (int i = threadIdx.x; i < param.d_model / 4; i += blockDim.x) { + half2 val_0; + half2 val_1; + fp8x4_e4m3_to_half2(&val_0, &val_1, src_ptr + i); + + val_0 = hmul2(val_0, scale); + val_1 = hmul2(val_1, scale); + + dst_ptr[2 * i + 0] = val_0; + dst_ptr[2 * i + 1] = val_1; + } +} + +template +void invokeQuantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam param) +{ + dim3 grid(param.token_num); + dim3 block(param.d_model); + FT_CHECK(block.x <= 1024); + if (block.x % 4 == 0) { + block.x /= 4; + } + quantizeMatrixRebuildPadding<<>>(param); +} + +template void invokeQuantizeMatrixRebuildPadding( + QuantizeMatrixRebuildPaddingParam param); + +#endif + +} // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/bert_preprocess_kernels.h b/src/fastertransformer/kernels/bert_preprocess_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..2de48657dc92fffa02e0b033e73fec0f3580c691 --- /dev/null +++ b/src/fastertransformer/kernels/bert_preprocess_kernels.h @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "src/fastertransformer/kernels/gen_relative_pos_bias.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include +#include +#ifdef ENABLE_FP8 +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#endif // ENABLE_FP8 + +namespace fastertransformer { + +void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num, + size_t* h_token_num, + int* tmp_mask_offset, + int* cu_seqlens, + const int* sequence_length, + const int batch_size, + const int max_seq_len, + cudaStream_t stream); + +inline void invokeGetPaddingOffset(size_t* h_pinned_token_num, + size_t* h_token_num, + int* tmp_mask_offset, + const int* sequence_length, + const int batch_size, + const int max_seq_len, + cudaStream_t stream) +{ + invokeGetPaddingOffsetAndCuSeqLens( + h_pinned_token_num, h_token_num, tmp_mask_offset, nullptr, sequence_length, batch_size, max_seq_len, stream); +} + +template +void invokeBuildEncoderAttentionMask( + T* attention_mask, const int* sequence_lengths, const int batch_size, const int max_seq_len, cudaStream_t stream); + +void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, + const int* sequence_length, + const int request_batch_size, + cudaStream_t stream); + +void invokeGetTrtPaddingOffset(int* trt_mha_padding_offset, + const int* sequence_length, + const int request_batch_size, + const int request_seq_len, + cudaStream_t stream); + +template +void invokeRebuildPadding( + T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); + +template +void invokeRemovePadding( + T* dst, const T* src, const int* padding_offset, const int token_num, const int hidden_dim, cudaStream_t stream); + +template +void invokeBuildRelativeAttentionBias(T* relative_attention_bias, + const T* relative_attention_bias_table, + const int head_num, + const int seq_len, + const int num_bucket, + const bool is_bidirectional, + const int max_distance, + const PositionEmbeddingType position_embedding_type, + cudaStream_t stream); + +template +struct getLastTokenDequantizeParam { + T_OUT* const output; + T_IN const* const input; + float const* const input_scale; + + const int batch_size; + const int max_seq_len; + const int d_model; + cudaStream_t stream; +}; + +template +void invokeGetLastTokenDequantize(getLastTokenDequantizeParam param); + +#ifdef ENABLE_FP8 +template +struct QuantizeMatrixRebuildPaddingParam { + T_OUT* dst; + const T_IN* src; + const int* padding_offset; + const int token_num; + const int d_model; + const float* scale; + cudaStream_t stream; +}; + +template +void invokeQuantizeMatrixRebuildPadding(QuantizeMatrixRebuildPaddingParam param); +#endif // ENABLE_FP8 + +} // namespace fastertransformer diff --git a/src/fastertransformer/kernels/custom_ar_kernels.cu b/src/fastertransformer/kernels/custom_ar_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..af8aee128f6f6e96abb4f58bc5bb7638836fcf85 --- /dev/null +++ b/src/fastertransformer/kernels/custom_ar_kernels.cu @@ -0,0 +1,398 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "custom_ar_kernels.h" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" + +namespace fastertransformer { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t hadd2(const uint32_t& a, const uint32_t& b) +{ + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t fadd(const uint32_t& a, const uint32_t& b) +{ + uint32_t c; + asm volatile("add.f32 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr) +{ +#if __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#else + __threadfence_system(); + asm volatile("st.global.volatile.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_addr) +{ +#if __CUDA_ARCH__ >= 700 + asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); +#else + asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Type Converter that packs data format to 128 bits data type +template +struct ARTypeConverter { + using Type = uint4; +}; + +#ifdef ENABLE_BF16 +template<> +struct ARTypeConverter<__nv_bfloat16> { + using Type = bf168; +}; +#endif + +// add two 128b data +template +inline __device__ T_IN add128b(T_IN a, T_IN b); + +template<> +inline __device__ uint4 add128b(uint4 a, uint4 b) +{ + uint4 c; + c.x = hadd2(a.x, b.x); + c.y = hadd2(a.y, b.y); + c.z = hadd2(a.z, b.z); + c.w = hadd2(a.w, b.w); + return c; +} + +template<> +inline __device__ uint4 add128b(uint4 a, uint4 b) +{ + uint4 c; + c.x = fadd(a.x, b.x); + c.y = fadd(a.y, b.y); + c.z = fadd(a.z, b.z); + c.w = fadd(a.w, b.w); + return c; +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ bf168 add128b(bf168 a, bf168 b) +{ + bf168 c; + c.x = bf16hadd2(a.x, b.x); + c.y = bf16hadd2(a.y, b.y); + c.z = bf16hadd2(a.z, b.z); + c.w = bf16hadd2(a.w, b.w); + return c; +} +#endif + +// init 128bits data with 0 +template +inline __device__ T init_packed_type(); + +template<> +inline __device__ uint4 init_packed_type() +{ + return make_uint4(0u, 0u, 0u, 0u); +} + +#ifdef ENABLE_BF16 +template<> +inline __device__ bf168 init_packed_type() +{ + bf168 val; + uint4& val_u = reinterpret_cast(val); + val_u = make_uint4(0u, 0u, 0u, 0u); + return val; +} +#endif + +template +static __global__ void oneShotAllReduceKernel(AllReduceParams params) +{ + // The block index. 
+ const int bidx = blockIdx.x; + // The thread index with the block. + const int tidx = threadIdx.x; + + // The number of elements packed into one for comms + static constexpr int NUM_ELTS = std::is_same::value ? 4 : 8; + + // Packed data type for comms + using PackedType = typename ARTypeConverter::Type; + + // The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128). + size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS; + // The end of the segment computed by that block. + size_t max_offset = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank); + + // Synchronize the ranks. + volatile uint32_t* barrier_d = params.peer_barrier_ptrs[params.local_rank]; + if (tidx < RANKS_PER_NODE) { + // The 1st block notifies the other ranks. + if (bidx == 0) { + params.peer_barrier_ptrs[tidx][params.local_rank] = params.barrier_flag; + } + + // Busy-wait until all ranks are ready. + while (barrier_d[tidx] < params.barrier_flag) {} + } + + // Make sure we can move on... + __syncthreads(); + + // The source pointers. Distributed round-robin for the different warps. + const T* src_d[RANKS_PER_NODE]; +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + int rank = (params.local_rank + ii) % RANKS_PER_NODE; + src_d[ii] = params.peer_comm_buffer_ptrs[rank]; + } + + // Each block accumulates the values from the different GPUs on the same node. + for (size_t iter_offset = offset; iter_offset < max_offset; iter_offset += blockDim.x * NUM_ELTS) { + // Iterate over the different ranks/devices on the node to load the values. + PackedType vals[RANKS_PER_NODE]; +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + vals[ii] = reinterpret_cast(&src_d[ii][iter_offset])[0]; + } + + // Sum the values from the different ranks. + PackedType sums = init_packed_type(); +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + sums = add128b(sums, vals[ii]); + } + + // Store to the destination buffer. + reinterpret_cast(¶ms.local_output_buffer_ptr[iter_offset])[0] = sums; + } +} + +template +static __global__ void twoShotAllReduceKernel(AllReduceParams params) +{ + + // The block index. + const int bidx = blockIdx.x; + // The thread index with the block. + const int tidx = threadIdx.x; + + // The number of elements packed into one for comms + static constexpr int NUM_ELTS = std::is_same::value ? 4 : 8; + + // Packed data type for comms + using PackedType = typename ARTypeConverter::Type; + + // The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128). + size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS + params.rank_offset; + // The end of the segment computed by that block. + size_t max_offset = min(offset + params.elts_per_block, params.elts_total); + + // Synchronize the ranks. + volatile uint32_t* barrier_d = params.peer_barrier_ptrs[params.local_rank]; + if (tidx < RANKS_PER_NODE) { + // The 1st block notifies the other ranks. + if (bidx == 0) { + params.peer_barrier_ptrs[tidx][params.local_rank] = params.barrier_flag; + } + + // Busy-wait until all ranks are ready. + while (barrier_d[tidx] < params.barrier_flag) {} + } + + // Make sure we can move on... + __syncthreads(); + + // The source pointers. Distributed round-robin for the different warps. 
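+    // (Illustrative example: with RANKS_PER_NODE == 8 and local_rank == 2, the loop below fills
+    //  src_d with the peer buffers in the order {2, 3, 4, 5, 6, 7, 0, 1}, so every rank starts
+    //  its reads from a different peer instead of all ranks hitting the same buffer first.)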
+ T* src_d[RANKS_PER_NODE]; + // The destination ranks for round-robin gathering + size_t dst_rank[RANKS_PER_NODE]; +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + int rank = (params.local_rank + ii) % RANKS_PER_NODE; + src_d[ii] = params.peer_comm_buffer_ptrs[rank]; + dst_rank[ii] = rank; + } + + // Each block accumulates the values from the different GPUs on the same node. + for (size_t local_offset = offset; local_offset < max_offset; local_offset += blockDim.x * NUM_ELTS) { + + // Iterate over the different ranks/devices on the node to load the values. + PackedType vals[RANKS_PER_NODE]; +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + vals[ii] = reinterpret_cast(&src_d[ii][local_offset])[0]; + } + + // Sum the values from the different ranks. + PackedType sums = init_packed_type(); +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + sums = add128b(sums, vals[ii]); + } + + // Store to the local buffer. + reinterpret_cast(&src_d[0][local_offset])[0] = sums; + } + + // sync threads to make sure all block threads have the sums + __syncthreads(); + + // barreris among the blocks with the same idx (release-acuqire semantics) + if (tidx < RANKS_PER_NODE) { + // The all blocks notifies the other ranks. + uint32_t flag_block_offset = RANKS_PER_NODE + bidx * RANKS_PER_NODE; + st_flag_release(params.barrier_flag, params.peer_barrier_ptrs[tidx] + flag_block_offset + params.local_rank); + + // Busy-wait until all ranks are ready. + uint32_t rank_barrier = 0; + uint32_t* peer_barrier_d = params.peer_barrier_ptrs[params.local_rank] + flag_block_offset + tidx; + do { + ld_flag_acquire(rank_barrier, peer_barrier_d); + } while (rank_barrier != params.barrier_flag); + } + + // sync threads to make sure all other ranks has the final partial results + __syncthreads(); + + // Gather all needed elts from other intra-node ranks + for (size_t local_offset = offset; local_offset < max_offset; local_offset += blockDim.x * NUM_ELTS) { +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + // use round-robin gathering from other ranks + int offset_rank = local_offset + (dst_rank[ii] - params.local_rank) * params.elts_per_rank; + reinterpret_cast(¶ms.local_output_buffer_ptr[offset_rank])[0] = + reinterpret_cast(&src_d[dst_rank[ii]][offset_rank])[0]; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void kernelLaunchConfig( + int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo, size_t data_type_bytes) +{ + assert(data_type_bytes == 2 || data_type_bytes == 4); + // NOTE: need to support FP16 and FP32 + size_t elts_per_thread = 16 / data_type_bytes; + size_t elts_per_warp = (16 * WARP_SIZE) / data_type_bytes; + switch (kernel_algo) { + case 0: { // one stage all reduce algo + assert(elts % elts_per_warp == 0); + if (elts < (elts_per_thread * DEFAULT_BLOCK_SIZE)) { // local reduce + threads_per_block = ((elts + elts_per_warp - 1) / elts_per_warp) * WARP_SIZE; + blocks_per_grid = 1; + } + else { // local reduce + if (elts % (elts_per_thread * threads_per_block) == 0) { + blocks_per_grid = + (elts + elts_per_thread * threads_per_block - 1) / (elts_per_thread * threads_per_block); + // NOTE: need to adjust here + if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) { + int iter_factor = 1; + while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) { + iter_factor += 1; + } + blocks_per_grid /= iter_factor; + } + } + else { + int total_threads = 
elts / elts_per_thread; + blocks_per_grid = 1; + while (total_threads % blocks_per_grid != 0 + || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) { + blocks_per_grid += 1; + } + threads_per_block = total_threads / blocks_per_grid; + } + } + break; + } + case 1: { // two stage all reduce algo + int total_threads = elts / RANKS_PER_NODE / RANKS_PER_NODE; + assert(elts / RANKS_PER_NODE % RANKS_PER_NODE == 0 && total_threads % WARP_SIZE == 0); + + while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) { + blocks_per_grid += 1; + } + + threads_per_block = total_threads / blocks_per_grid; + + // NOTE: need to adjust here + if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) { + int iter_factor = 1; + while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) { + iter_factor += 1; + } + blocks_per_grid /= iter_factor; + } + break; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream) +{ + size_t elts_total = param.elts_total; + int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE; + int kernel_algo = 1; + if (elts_total * sizeof(T) <= DEFALUT_ALGO_AR_SIZE_THRESHOLD) { + kernel_algo = 0; + } + + kernelLaunchConfig(blocks_per_grid, threads_per_block, elts_total, kernel_algo, sizeof(T)); + + if (kernel_algo == 0) { + param.elts_per_rank = elts_total; + param.elts_per_block = param.elts_per_rank / blocks_per_grid; + oneShotAllReduceKernel<<>>(param); + } + else { + param.elts_per_rank = param.elts_total / RANKS_PER_NODE; + param.elts_per_block = param.elts_per_rank / blocks_per_grid; + param.rank_offset = param.rank * param.elts_per_rank; + twoShotAllReduceKernel<<>>(param); + } +} + +// Template instantiation +template void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(AllReduceParams<__nv_bfloat16>& param, + cudaStream_t stream); +#endif +template void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream); +} // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/custom_ar_kernels.h b/src/fastertransformer/kernels/custom_ar_kernels.h new file mode 100644 index 0000000000000000000000000000000000000000..aba07658c5105955a6e96c6cb544863f602ed033 --- /dev/null +++ b/src/fastertransformer/kernels/custom_ar_kernels.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include + +#include + +#include "src/fastertransformer/utils/cuda_utils.h" + +#define CUSTOM_AR_SIZE_THRESHOLD 50331648 +#define MAX_ALL_REDUCE_BLOCKS 24 +#define FLAG(a) ((uint32_t)((a) % 0x146)) +#define RANKS_PER_NODE 8 +#define WARP_SIZE 32 +#define DEFAULT_BLOCK_SIZE 1024 +#define DEFALUT_ALGO_AR_SIZE_THRESHOLD 393216 + +namespace fastertransformer { + +#ifdef ENABLE_BF16 +typedef struct bf168 { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +} bf168; +#endif + +template +struct AllReduceParams { + size_t elts_total; + size_t elts_per_rank; + size_t elts_per_block; + size_t rank_offset; + size_t rank, local_rank, node_id; + uint32_t barrier_flag; + uint32_t* peer_barrier_ptrs[RANKS_PER_NODE]; + T* peer_comm_buffer_ptrs[RANKS_PER_NODE]; + T* local_output_buffer_ptr; +}; + +template +void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, cudaStream_t stream); + +void kernelLaunchConfig(int& blocks_per_grid, int& threads_per_block, size_t elts, int kernel_algo); + +} // namespace fastertransformer \ No newline at end of file diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu new file mode 100644 index 0000000000000000000000000000000000000000..2b5cb081d4533559e2f48c64167688787a117094 --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu @@ -0,0 +1,59 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +template +void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + switch (params.hidden_size_per_head) { + case 128: + mmha_launch_kernel(params, stream); + break; + default: + assert(false); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); +} +#endif diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h new file mode 100644 index 0000000000000000000000000000000000000000..c56e87358be0240cfc4950d8a4d7332441b26cca --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.h @@ -0,0 +1,147 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The structure of parameters for the masked multihead attention kernel. +// +// We use the following terminology to describe the different dimensions. 
+// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. + +template +struct Multihead_attention_params_base { + + // The output buffer. Dimensions B x D. + T* out = nullptr; + + // The input Qs and the associated bias. Dimensions B x D and D, resp. + const T *q = nullptr, *q_bias = nullptr; + // The input Ks and the associated bias. Dimensions B x D and D, resp. + const T *k = nullptr, *k_bias = nullptr; + // The input Vs and the associated bias. Dimensions B x D and D, resp. + const T *v = nullptr, *v_bias = nullptr; + + // The cache for the Ks. The size must be at least B x L x D. + T* k_cache = nullptr; + // The cache for the Vs. The size must be at least B x L x D. + T* v_cache = nullptr; + // The indirections to use for cache when beam sampling. + const int* cache_indir = nullptr; + + // scales + const float* query_weight_output_scale = nullptr; + const float* attention_qk_scale = nullptr; + const float* attention_output_weight_input_scale_inv = nullptr; + + // Stride to handle the case when KQV is a single buffer + int stride = 0; + + // The batch size. + int batch_size = 0; + // The beam width + int beam_width = 0; + // The sequence length. + int memory_max_len = 0; + // The number of heads (H). + int num_heads = 0; + // The hidden dimension per head (Dh). + int hidden_size_per_head = 0; + // The per-head latent space reserved for rotary embeddings. + int rotary_embedding_dim = 0; + // The maximum length of input sentences. + int max_input_length = 0; + // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? + int timestep = 0; + // The current timestep of each sentences (support different timestep for different sentences) + + // The 1.f / sqrt(Dh). Computed on the host. + float inv_sqrt_dh = 0.0f; + + // Used when we have some input context like gpt + const int* total_padding_tokens = nullptr; + + const bool* masked_tokens = nullptr; + const int* prefix_prompt_lengths = nullptr; + int max_prefix_prompt_length = 0; + + const T* relative_attention_bias = nullptr; + int relative_attention_bias_stride = 0; + // The slope per head of linear position bias to attention score (H). 
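+    // (e.g. an ALiBi-style bias, where the score for head h is shifted by
+    //  linear_bias_slopes[h] * (key_position - query_position); left as nullptr when unused.)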
+ const T* linear_bias_slopes = nullptr; + + const T* ia3_key_weights = nullptr; + const T* ia3_value_weights = nullptr; + const int* ia3_tasks = nullptr; + + const float* qkv_scale_out = nullptr; + const float* attention_out_scale = nullptr; + int int8_mode = 0; +}; + +template +struct Multihead_attention_params: public Multihead_attention_params_base { + // allows to exist attention eary + bool* finished = nullptr; + + // required in case of masked attention with different length + const int* length_per_sample = nullptr; + + T** k_cache_per_sample = nullptr; + T** v_cache_per_sample = nullptr; + size_t kv_cache_per_sample_offset = 0; + bool k_cache_interleaved = true; +}; + +template +using Masked_multihead_attention_params = Multihead_attention_params; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream); +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu new file mode 100644 index 0000000000000000000000000000000000000000..928fadc89540b256d64ce9a8e12c96447e9c6d82 --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_utils.h" +#include +#include +#include + +#include "decoder_masked_multihead_attention_template.cuh" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel \ + <<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! 
Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + // constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = params.timestep; + + FT_CHECK(params.cache_indir == nullptr); + + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 128, 128, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 128, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ddbbe446e227f7249638f0cc931fe7e4f2f9b9e2 --- /dev/null +++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh @@ -0,0 +1,1820 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" +#include +#include +#include + +// #define MMHA_USE_HMMA_FOR_REDUCTION + +// Below are knobs to extend FP32 accumulation for higher FP16 accuracy + +// Does not seem to affect the accuracy that much +// #define MMHA_USE_FP32_ACUM_FOR_FMA + +// Seems to slightly improve the accuracy +#define MMHA_USE_FP32_ACUM_FOR_OUT + +#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) + // Does not seem to improve the accuracy + //#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#endif + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. +// +// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use +// 64, 128 and 256 threads per block. +// +// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to +// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The +// cache buffer helps with memory accesses and contains keys with bias. +// +// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and +// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The +// values for x are chosen to create chunks of 16 bytes. +// +// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs +// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At +// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an +// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. +// +// After that loop, a parallel softmax is computed across the different Q * K^T values stored in +// shared memory. +// +// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many +// timesteps are computed by loop iteration. As with the keys, the values are read from a cache +// except for the current timestep. The layout of the cache buffer for the values is much simpler +// as it is [B, H, L, Dh]. 
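+//
+// Illustrative index math (assuming FP16, so x == 8): the key element for batch b, head h,
+// channel d and timestep l lives in the K cache at linear offset
+//
+//     (((b*H + h) * (Dh/8) + d/8) * L + l) * 8 + (d%8),
+//
+// so the 16-byte chunk holding channels [8*(d/8), 8*(d/8)+7] at timestep l sits right next to the
+// chunk for timestep l+1, which is what lets a thread stream along L with 16-byte loads in the
+// Q*K^T loop.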
+// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_m_ {}; + +template<> +struct Qk_vec_m_ { + using Type = float; +}; +template<> +struct Qk_vec_m_ { + using Type = float2; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint2; +}; +template<> +struct Qk_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct Qk_vec_m_<__nv_bfloat16, 32> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 64> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 128> { + using Type = bf16_4_t; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 256> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 32> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 64> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 128> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 256> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_k_ { + using Type = typename Qk_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 32> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 64> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 128> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 256> { + using Type = float4; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_m_ {}; + +template<> +struct K_vec_m_ { + using Type = float; +}; +template<> +struct K_vec_m_ { + using Type = float2; +}; +template<> +struct K_vec_m_ { + using Type = float4; +}; +template<> +struct K_vec_m_ { + using Type = uint32_t; +}; +template<> +struct K_vec_m_ { + using Type = uint2; +}; +template<> +struct K_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct K_vec_m_<__nv_bfloat16, 4> { + using Type = __nv_bfloat162; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 2> { + using Type = bf16_4_t; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 1> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +// NOTE: THREADS_PER_KEY * sizeof(K_vec_m_) = 128 bytes +#ifdef ENABLE_FP8 +template<> +struct K_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct K_vec_m_<__nv_fp8_e4m3, 2> { + using Type = fp8_4_t; +}; // Defined for compilation-purpose only, do not use +template<> +struct K_vec_m_<__nv_fp8_e4m3, 1> { + using Type = fp8_4_t; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_k_ { + using Type = typename K_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct K_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct K_vec_k_<__nv_fp8_e4m3, 2> { + using Type = float4; +}; // Defined for compilation-purpose only, do not use +template<> +struct K_vec_k_<__nv_fp8_e4m3, 1> { + using 
Type = float4; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_m_ {}; + +template<> +struct V_vec_m_ { + using Type = float; +}; +template<> +struct V_vec_m_ { + using Type = float2; +}; +template<> +struct V_vec_m_ { + using Type = float4; +}; +template<> +struct V_vec_m_ { + using Type = uint32_t; +}; +template<> +struct V_vec_m_ { + using Type = uint2; +}; +template<> +struct V_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_m_<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +template<> +struct V_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 8> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 16> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_k_ { + using Type = typename V_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct V_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 8> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 16> { + using Type = float4; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct Qk_vec_acum_fp32_ {}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float4; +}; +// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_FP8 +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_acum_fp32_ {}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = 
float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_FP8 +// template<> +// struct K_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct K_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif // MMHA_USE_FP32_ACUM_FOR_FMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template +struct V_vec_acum_fp32_ {}; + +template<> +struct V_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +// template<> +// struct V_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct V_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} +#ifdef ENABLE_FP8 +// fp8_t +template<> +__inline__ __device__ float vec_conversion(const __nv_fp8_e4m3& a) +{ + return float(a); +} +template<> +__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a) +{ + return __nv_fp8_e4m3(a); +} +// fp8_2_t +template<> +__inline__ __device__ float2 vec_conversion(const fp8_2_t& a) +{ + return float2(a); +} +template<> +__inline__ __device__ fp8_2_t vec_conversion(const float2& a) +{ + return fp8_2_t(a); +} +// fp8_4_t +template<> +__inline__ __device__ float4 vec_conversion(const fp8_4_t& a) +{ + return float4(a); +} +template<> +__inline__ __device__ fp8_4_t vec_conversion(const float4& a) +{ + return fp8_4_t(a); +} +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +{ +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = K_vec; +#endif + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. 
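+    // (Concrete example, assuming THREADS_PER_KEY == 4: the XOR-shuffle loop below runs with
+    //  mask = 2 and then mask = 1, so each aligned group of 4 lanes exchanges its partial sums
+    //  and all 4 lanes end up with the complete Q*K^T dot product for their key.)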
+ float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) + { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) +{ + float4 c; + float zero = 0.f; + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float block_sum(float* red_smem, float sum) +{ + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +// Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. 
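+    // (Every warp re-reduced the per-warp partial sums held in red_smem, so lane 0 of each warp
+    //  already holds the block-wide total; the shuffle below broadcasts lane 0's value so that
+    //  all threads return the same result.)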
+ return __shfl_sync(uint32_t(-1), sum, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float& dst, float src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t& dst, float src) +{ + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t& dst, float2 src) +{ + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) +{ + dst = __float2bfloat16(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, Float4_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4& dst, Float8_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
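+// Illustrative sketch only (not used by the kernels in this file): with MMHA_USE_FP32_ACUM_FOR_OUT
+// the V-loop accumulates in FP32 (Float8_) and narrows to the storage type only on the final
+// store, via the convert_from_float overloads above. The helper name below is hypothetical and
+// exists purely as an example of that pattern.
+inline __device__ uint4 example_store_fp32_accum_as_half8(Float8_ acc)
+{
+    uint4 out;                     // 8 fp16 values packed into four 32-bit registers
+    convert_from_float(out, acc);  // Float8_ -> uint4 overload defined above
+    return out;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+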
+#ifdef ENABLE_FP8 +inline __device__ void convert_from_float(fp8_4_t& dst, float4 src) +{ + dst = fp8_4_t(src); +} +inline __device__ void convert_from_float(fp8_2_t& dst, float2 src) +{ + dst = fp8_2_t(src); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2& dst, float2 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4& dst, float4 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) +{ + return u.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) +{ + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float float_from_int8(int8_t u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 float_from_int8(int16_t u) +{ + union { + int16_t int16; + int8_t int8[2]; + }; + int16 = u; + return make_float2(int8[0], int8[1]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 float_from_int8(int32_t u) +{ + union { + int32_t int32; + int8_t int8[4]; + }; + int32 = u; + return make_float4(int8[0], int8[1], int8[2], int8[3]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off 
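+// (The int64_t overload below splits the 8 packed int8 values into four int16 lanes and expands
+//  each lane's two bytes with the int16_t overload above, producing a Float8_ of 8 floats.)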
+inline __device__ Float8_ float_from_int8(int64_t u) +{ + union { + int64_t int64; + int16_t int16[4]; + }; + int64 = u; + return Float8_ {float_from_int8(int16[0]), + float_from_int8(int16[1]), + float_from_int8(int16[2]), + float_from_int8(int16[3])}; +} +// clang-format on + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int8_t cast_to_int8(float val) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int32_t cast_to_int8(float4 val) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int8[0] = cast_to_int8(val.x); + int8[1] = cast_to_int8(val.y); + int8[2] = cast_to_int8(val.z); + int8[3] = cast_to_int8(val.w); + return int32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int64_t cast_to_int8(Float8_ val) +{ + union { + int8_t int8[8]; + int64_t int64; + }; + int8[0] = cast_to_int8(val.x.x); + int8[1] = cast_to_int8(val.x.y); + int8[2] = cast_to_int8(val.y.x); + int8[3] = cast_to_int8(val.y.y); + int8[4] = cast_to_int8(val.z.x); + int8[5] = cast_to_int8(val.z.y); + int8[6] = cast_to_int8(val.w.x); + int8[7] = cast_to_int8(val.w.y); + return int64; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct kernel_type_t { + using Type = T; +}; + +#ifdef ENABLE_FP8 +template<> +struct kernel_type_t<__nv_fp8_e4m3> { + using Type = float; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline size_t +smem_size_in_bytes(const Multihead_attention_params& params, int threads_per_value, int threads_per_block) +{ + using Tk = typename kernel_type_t::Type; + // The amount of shared memory needed to store the Q*K^T values in float. + const int max_timesteps = min(params.timestep, params.memory_max_len); + size_t qk_sz = div_up(max_timesteps + 1, 4) * 16; + + // The extra memory needed if we are not using floats for the final logits. + size_t logits_sz = 0; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(Tk) != 4) { + // TDOD + logits_sz = div_up(max_timesteps + 1, 4) * 4 * sizeof(Tk); + } +#endif + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(Tk) / 2; + + // The max. + return max(softmax_sz, red_sz); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ constexpr uint32_t shfl_mask(int threads) +{ + return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template