Unverified Commit 4767b04d authored by q.yao, committed by GitHub

Support Runtime tensor parallelism (#158)

* works on internlm and vicuna

* support GQA

* remove comment

* update readme, add logger, default tp=1

* remove log
parent 981a4610
......@@ -84,10 +84,13 @@ python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-ch
python -m lmdeploy.turbomind.chat ./workspace
```
```{note}
When inferring with FP16 precision, the InternLM-7B model requires at least 15.7G of GPU memory overhead on TurboMind. It is recommended to use NVIDIA cards such as 3090, V100, A100, etc.
Disabling GPU ECC can free up 10% of GPU memory; run `sudo nvidia-smi --ecc-config=0` and reboot the system.
```
> **Note**<br />
> When inferring with FP16 precision, the InternLM-7B model requires at least 15.7 GB of GPU memory on TurboMind. <br />
> It is recommended to use NVIDIA cards such as the 3090, V100, and A100.<br />
> Disabling GPU ECC can free up 10% of GPU memory; run `sudo nvidia-smi --ecc-config=0` and reboot the system.
> **Note**<br />
> Tensor parallelism is available for inference on multiple GPUs. Add `--tp=<num_gpu>` to the `chat` command to enable runtime TP, as in the example below.
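
For example, an illustrative invocation on two GPUs (the workspace path follows the earlier example; `2` is a placeholder for the number of available GPUs):

```
python -m lmdeploy.turbomind.chat ./workspace --tp=2
```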
#### Serving
......@@ -163,6 +166,9 @@ Then adjust `workspace/triton_models/weights/config.ini`
Here are the [quantization test results](./docs/en/quantization.md).
> **Warning**<br />
> Runtime tensor parallelism is not available for quantized models. Set `--tp` on `deploy` to enable static TP, as in the example below.
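
For instance, a hedged illustration of enabling static TP at deploy time (the model path is a placeholder and `2` stands in for the desired GPU count):

```
python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-chat-7b --tp=2
```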
## Contributing
We appreciate all contributions to LMDeploy. Please refer to [CONTRIBUTING.md](.github/CONTRIBUTING.md) for the contributing guideline.
......
......@@ -83,10 +83,12 @@ python3 -m lmdeploy.serve.turbomind.deploy internlm-chat-7b /path/to/internlm-ch
python3 -m lmdeploy.turbomind.chat ./workspace
```
```{note}
When running FP16 inference of the InternLM-7B model, TurboMind needs at least 15.7 GB of GPU memory. NVIDIA cards such as the 3090, V100, and A100 are recommended.
Disabling the GPU's ECC can free up 10% of GPU memory; run `sudo nvidia-smi --ecc-config=0` and reboot the system for it to take effect.
```
> **Note**<br />
> When running FP16 inference of the InternLM-7B model, TurboMind needs at least 15.7 GB of GPU memory. NVIDIA cards such as the 3090, V100, and A100 are recommended.<br />
> Disabling the GPU's ECC can free up 10% of GPU memory; run `sudo nvidia-smi --ecc-config=0` and reboot the system for it to take effect.
> **Note**<br />
> Tensor parallelism can be used to run inference on multiple GPUs. Add `--tp=<num_gpu>` to the `chat` command to enable runtime TP.
#### Serving
......@@ -162,6 +164,9 @@ python3 -m lmdeploy.lite.apis.kv_qparams \
Here are the [quantization test results](./docs/zh_cn/quantization.md).
> **Warning**<br />
> Quantized deployment does not support runtime tensor parallelism. To use tensor parallelism, set the `tp` parameter at `deploy` time.
## Contributing
We appreciate every contributor's effort to improve LMDeploy. Please refer to the [contributing guide](.github/CONTRIBUTING.md) for guidance on how to contribute to the project.
......
......@@ -76,10 +76,11 @@ def main(model_path: str,
concurrency: int = 1,
input_seqlen: int = 0,
output_seqlen: int = 512,
test_round: int = 10):
test_round: int = 10,
tp: int = 1):
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = TurboMind(model_path=model_path)
tm_model = TurboMind(model_path=model_path, tp=tp)
warmup(tm_model, concurrency, output_seqlen)
......
......@@ -54,11 +54,11 @@ def sample_requests(
class Engine:
def __init__(self, model_path: str):
def __init__(self, model_path: str, tp: int = 1):
tokenizer_model_path = osp.join(model_path, 'triton_models',
'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = TurboMind(model_path=model_path)
tm_model = TurboMind(model_path=model_path, tp=tp)
self.tm_model = tm_model
self.tokenizer = tokenizer
......@@ -117,9 +117,10 @@ class Engine:
def main(dataset: str,
model_path: str,
concurrency: int = 1,
num_prompts: int = 1000):
num_prompts: int = 1000,
tp: int = 1):
engine = Engine(model_path)
engine = Engine(model_path, tp=tp)
tokenizer = engine.tokenizer
requests = sample_requests(dataset, num_prompts, tokenizer)
......
......@@ -52,7 +52,7 @@ def stream_callback(que, result, error):
def get_logger(log_file=None, log_level=logging.INFO):
"""Return the logger."""
from .utils import get_logger
from lmdeploy.turbomind.utils import get_logger
logger = get_logger('service.ft', log_file=log_file, log_level=log_level)
return logger
......
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import List, Optional, Union
from typing import List, Union
import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype
logger_initialized = {}
def get_logger(name: str,
log_file: Optional[str] = None,
log_level: int = logging.INFO,
file_mode: str = 'w'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified, a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
Returns:
logging.Logger: The expected logger.
"""
# use logger in mmengine if exists.
try:
from mmengine.logging import MMLogger
if MMLogger.check_instance_created(name):
logger = MMLogger.get_instance(name)
else:
logger = MMLogger.get_instance(name,
logger_name=name,
log_file=log_file,
log_level=log_level,
file_mode=file_mode)
return logger
except Exception:
pass
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
# handle duplicate logs to the console
for handler in logger.root.handlers:
if type(handler) is logging.StreamHandler:
handler.setLevel(logging.ERROR)
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
if log_file is not None:
# The stdlib FileHandler appends ('a') by default; file_mode lets callers
# override that, and this helper defaults to 'w' instead.
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
logger.setLevel(log_level)
logger_initialized[name] = True
return logger
def prepare_tensor(name, input_tensor):
"""Create grpcclient's InferInput instance according to a given tensor."""
......
......@@ -29,7 +29,10 @@ def valid_str(string, coding='utf-8'):
return ret
def main(model_path, session_id: int = 1, repetition_penalty: float = 1.0):
def main(model_path,
session_id: int = 1,
repetition_penalty: float = 1.0,
tp=1):
"""An example to perform model inference through the command line
interface.
......@@ -39,7 +42,7 @@ def main(model_path, session_id: int = 1, repetition_penalty: float = 1.0):
"""
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id)
tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
generator = tm_model.create_instance()
nth_round = 1
......
......@@ -13,6 +13,7 @@ from torch.nn.utils.rnn import pad_sequence
import lmdeploy
from lmdeploy.model import MODELS
from lmdeploy.turbomind.utils import get_logger
# TODO: find another way to import _turbomind
lmdeploy_dir = osp.split(lmdeploy.__file__)[0]
......@@ -69,14 +70,11 @@ class TurboMind:
Args:
model_path (str): the path of turbomind's model
data_type (str): the data type
eos_id (int): eos token id
tp (int): the number of GPUs used for tensor parallelism
"""
def __init__(self,
model_path: str,
data_type: str = 'fp16',
eos_id: int = 2):
def __init__(self, model_path: str, eos_id: int = 2, tp: int = 1):
self.eos_id = eos_id
# TODO: support mpi
......@@ -84,8 +82,9 @@ class TurboMind:
node_num = 1
# read meta from model path
self.gpu_count = 1
self.gpu_count = tp
self.session_len = 2048
data_type = 'fp16'
ini_path = osp.join(model_path, 'triton_models/weights/config.ini')
with open(ini_path, 'r') as f:
parser = ConfigParser()
......@@ -97,10 +96,14 @@ class TurboMind:
section_name = 'llama'
if len(section_name) > 0:
self.gpu_count = parser.getint(section_name,
'tensor_para_size')
tp_cfg = parser.getint(section_name, 'tensor_para_size')
self.session_len = parser.getint(section_name, 'session_len')
if tp_cfg != 1 and tp_cfg != tp:
get_logger('turbomind').info(
f'found tp={tp_cfg} in config.ini.')
self.gpu_count = tp_cfg
self.model_name = parser.get(section_name, 'model_name')
data_type = parser.get(section_name, 'weight_type')
model = MODELS.get(self.model_name)()
self.stop_words = _stop_words(model.stop_words)
......
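
As a usage reference, here is a minimal sketch (not part of the diff) of constructing `TurboMind` with runtime TP, mirroring the updated `chat` entry point; the workspace path, `tp=2`, and the import forms are assumptions:

```python
import os.path as osp

from lmdeploy import turbomind as tm  # assumed import form
from lmdeploy.turbomind.tokenizer import Tokenizer  # assumed module path

model_path = './workspace'  # output of the deploy step (illustrative)
tokenizer = Tokenizer(osp.join(model_path, 'triton_models', 'tokenizer'))

# tp asks the engine to shard weights across this many GPUs at load time; per this
# patch, a non-default tensor_para_size found in config.ini (static TP) wins instead.
tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=2)
generator = tm_model.create_instance()
```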
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Optional
logger_initialized = {}
def get_logger(name: str,
log_file: Optional[str] = None,
log_level: int = logging.INFO,
file_mode: str = 'w'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified, a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
Returns:
logging.Logger: The expected logger.
"""
# use logger in mmengine if exists.
try:
from mmengine.logging import MMLogger
if MMLogger.check_instance_created(name):
logger = MMLogger.get_instance(name)
else:
logger = MMLogger.get_instance(name,
logger_name=name,
log_file=log_file,
log_level=log_level,
file_mode=file_mode)
return logger
except Exception:
pass
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
# handle hierarchical names
# e.g., logger "a" is initialized, then logger "a.b" will skip the
# initialization since it is a child of "a".
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
# handle duplicate logs to the console
for handler in logger.root.handlers:
if type(handler) is logging.StreamHandler:
handler.setLevel(logging.ERROR)
stream_handler = logging.StreamHandler()
handlers = [stream_handler]
if log_file is not None:
# The stdlib FileHandler appends ('a') by default; file_mode lets callers
# override that, and this helper defaults to 'w' instead.
file_handler = logging.FileHandler(log_file, file_mode)
handlers.append(file_handler)
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s')
for handler in handlers:
handler.setFormatter(formatter)
handler.setLevel(log_level)
logger.addHandler(handler)
logger.setLevel(log_level)
logger_initialized[name] = True
return logger
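
A hedged usage sketch of this helper (the logger name and message are illustrative):

```python
import logging

from lmdeploy.turbomind.utils import get_logger

# The first call for a given name attaches a StreamHandler (plus a FileHandler when
# log_file is given); later calls with the same name return the cached logger.
logger = get_logger('turbomind', log_level=logging.INFO)
logger.info('runtime tensor parallelism enabled')
```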
......@@ -21,6 +21,7 @@
#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/memory_utils.h"
#include <filesystem>
namespace turbomind {
......@@ -99,25 +100,135 @@ void mallocWeights(LlamaDenseWeight<T>& weights, bool bias)
}
template<typename T>
void loadWeights(LlamaDenseWeight<T>& w, std::string prefix, int rank, FtCudaDataType model_file_type)
void loadWeights(LlamaDenseWeight<T>& w,
std::string prefix,
int rank,
FtCudaDataType model_file_type,
size_t tensor_para_size,
int slice_dim = 0,
std::vector<size_t> slice_shape = {})
{
prefix += "." + std::to_string(rank);
auto max_prefix = prefix + "." + std::to_string(tensor_para_size - 1);
const auto type = model_file_type;
bool enable_slice = true;
// Disable slicing when the tensor parallel size is 1
if (tensor_para_size <= 1) {
enable_slice = false;
}
else {
// Disable slice if weight has already been sliced
if (std::filesystem::exists(max_prefix + ".weight") || std::filesystem::exists(max_prefix + ".qweight")) {
TM_LOG_DEBUG("TP weight exists. Disable runtime TP.");
enable_slice = false;
}
}
size_t dim0 = w.input_dims;
size_t dim1 = w.output_dims;
if (enable_slice) {
// multiply by tp size to recover the full (pre-slicing) dimension
if (slice_dim == 0) {
dim0 = dim0 * tensor_para_size;
if (slice_shape.size() == 0) {
slice_shape = {dim0};
}
}
else {
dim1 = dim1 * tensor_para_size;
if (slice_shape.size() == 0) {
slice_shape = {dim1};
}
}
prefix += "." + std::to_string(0);
}
else {
prefix += "." + std::to_string(rank);
}
if (w.bias) {
loadWeightFromBin((T*)w.bias, {w.output_dims}, prefix + ".bias", type);
std::vector<ConcateSlice> bias_slices{};
if (enable_slice) {
if (slice_dim == 1) {
size_t start = 0;
ConcateSlice slice0{.slices = {{0, 1}}};
ConcateSlice slice1{.slices = {{}}};
for (auto len : slice_shape) {
size_t stride = len / tensor_para_size;
slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
start += len;
}
bias_slices = {slice0, slice1};
}
}
loadWeightFromBin((T*)w.bias, {1, dim1}, prefix + ".bias", type, bias_slices);
}
const size_t bit_size = getBitSize(w.type);
if (bit_size >= 16) { // fp16, fp32
loadWeightFromBin((T*)w.kernel, {w.input_dims, w.output_dims}, prefix + ".weight", type);
std::vector<ConcateSlice> weight_slices{};
if (enable_slice) {
if (slice_dim == 1) {
size_t start = 0;
ConcateSlice slice0{.slices = {{0, dim0}}};
ConcateSlice slice1{.slices = {{}}};
for (auto len : slice_shape) {
size_t stride = len / tensor_para_size;
slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
start += len;
}
weight_slices = {slice0, slice1};
}
else {
size_t start = 0;
ConcateSlice slice0{.slices = {}};
ConcateSlice slice1{.slices = {{0, dim1}}};
for (auto len : slice_shape) {
size_t stride = len / tensor_para_size;
slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
start += len;
}
weight_slices = {slice0, slice1};
}
}
loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type, weight_slices);
}
else { // int8, int4
const int factor = sizeof(float) * 8 / bit_size;
FT_CHECK(w.input_dims % factor == 0);
FT_CHECK(dim0 % factor == 0);
const auto f32_type = FtCudaDataType::FP32;
loadWeightFromBin((float*)w.kernel, {w.input_dims / factor, w.output_dims}, prefix + ".qweight", f32_type);
loadWeightFromBin((T*)w.scales, {w.output_dims}, prefix + ".scales", type);
loadWeightFromBin((T*)w.zeros, {w.output_dims}, prefix + ".zeros", type);
std::vector<ConcateSlice> weight_slices{};
std::vector<ConcateSlice> bias_slices{};
if (enable_slice) {
if (slice_dim == 1) {
size_t start = 0;
ConcateSlice slice0{.slices = {{0, dim0}}};
ConcateSlice slice1{.slices = {{}}};
for (auto len : slice_shape) {
size_t stride = len / tensor_para_size;
slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
start += len;
}
weight_slices = {slice0, slice1};
ConcateSlice bias_slice0{.slices = {{0, 1}}};
bias_slices = {bias_slice0, slice1};
}
else {
size_t start = 0;
ConcateSlice slice0{.slices = {}};
ConcateSlice slice1{.slices = {{0, dim1}}};
for (auto len : slice_shape) {
size_t stride = len / factor / tensor_para_size;
slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
start += len;
}
weight_slices = {slice0, slice1};
}
}
loadWeightFromBin((float*)w.kernel, {dim0 / factor, dim1}, prefix + ".qweight", f32_type, weight_slices);
loadWeightFromBin((T*)w.scales, {1, dim1}, prefix + ".scales", type, bias_slices);
loadWeightFromBin((T*)w.zeros, {1, dim1}, prefix + ".zeros", type, bias_slices);
}
}
......@@ -158,11 +269,17 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
(T*)self_attn_norm_weights, {hidden_units_}, dir_path + ".attention_norm.weight", model_file_type);
loadWeightFromBin((T*)ffn_norm_weights, {hidden_units_}, dir_path + ".ffn_norm.weight", model_file_type);
loadWeights(self_attn_weights.qkv, dir_path + ".attention.w_qkv", tensor_para_rank_, type);
loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type);
loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type);
loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type);
loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type);
loadWeights(self_attn_weights.qkv,
dir_path + ".attention.w_qkv",
tensor_para_rank_,
type,
tensor_para_size_,
1,
{head_num_ * size_per_head_, kv_head_num_ * size_per_head_, kv_head_num_ * size_per_head_});
loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_, 0);
loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_, 1);
loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_, 1);
loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_, 0);
// load kv_cache quant scale
// if file not exist, get empty vector
......
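
To make the slice arithmetic above concrete, here is a small Python sketch (not part of the patch; head counts are illustrative) of how the per-rank column ranges of the fused `w_qkv` weight follow from `slice_shape = {q, kv, kv}` under GQA. Each block is partitioned independently, so every rank keeps K/V columns aligned with its share of Q columns:

```python
def qkv_column_slices(slice_shape, tp, rank):
    """Return the (start, end) column ranges this rank loads from the fused QKV weight."""
    slices = []
    start = 0
    for length in slice_shape:
        stride = length // tp                      # this rank's share of the block
        slices.append((start + stride * rank,
                       start + stride * (rank + 1)))
        start += length                            # jump to the next Q/K/V block
    return slices


# Example: 32 query heads, 8 KV heads (GQA), head size 128, tp=2, rank 1.
q, kv = 32 * 128, 8 * 128
print(qkv_column_slices([q, kv, kv], tp=2, rank=1))
# -> [(2048, 4096), (4608, 5120), (5632, 6144)]
```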
......@@ -301,16 +301,20 @@ template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size);
// loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or
// the product of the elements in shape is 0, this function will return an empty vector.
template<typename T>
std::vector<T> loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename)
std::vector<T>
loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename, std::vector<ConcateSlice> slices = {})
{
if (shape.size() > 2) {
printf("[ERROR] shape should have less than two dims \n");
return std::vector<T>();
}
size_t dim0 = shape[0], dim1 = 1;
if (shape.size() == 2) {
dim1 = shape[1];
}
if (slices.size() == 0) {
size_t size = dim0 * dim1;
if (size == 0) {
TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
......@@ -342,17 +346,113 @@ std::vector<T> loadWeightFromBinHelper(std::vector<size_t> shape, std::string fi
in.close();
// If we succeed, return an array with values.
return host_array;
}
else {
// concatenate all requested slices along each dim
if (slices.size() != shape.size()) {
printf("[ERROR] slices should have same dims as shape \n");
return std::vector<T>();
}
// get slices
ConcateSlice slice0{.slices = {{0, dim0}}};
ConcateSlice slice1{.slices = {{0, dim1}}};
if (slices.size() > 0 && slices[0].slices.size() > 0) {
slice0 = slices[0];
}
if (shape.size() == 2 && slices[1].slices.size() > 0) {
slice1 = slices[1];
}
size_t w0 = 0;
for (auto& s : slice0.slices) {
if (s.second > dim0) {
s.second = dim0;
}
if (s.second < s.first) {
printf("[ERROR] slice0: end < start \n");
return std::vector<T>();
}
w0 += s.second - s.first;
}
size_t w1 = 0;
for (auto& s : slice1.slices) {
if (s.second > dim1) {
s.second = dim1;
}
if (s.second < s.first) {
printf("[ERROR] slice1: end < start \n");
return std::vector<T>();
}
w1 += s.second - s.first;
}
size_t size = w0 * w1;
size_t loaded_data_size = size * sizeof(T);
TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename + " with slice.");
if (size == 0) {
TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
return std::vector<T>();
}
std::vector<T> host_array(size);
std::ifstream in(filename, std::ios::in | std::ios::binary);
if (!in.is_open()) {
TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str());
return std::vector<T>();
}
char* host_ptr = (char*)host_array.data();
if (slice1.slices.size() == 0
|| (slice1.slices.size() == 1 && slice1.slices[0].second - slice1.slices[0].first == dim1)) {
for (auto& s : slice0.slices) {
size_t read_size = (s.second - s.first) * dim1 * sizeof(T);
size_t pos = s.first * dim1;
in.seekg(pos * sizeof(T));
in.read((char*)host_ptr, read_size);
host_ptr += read_size;
}
in.close();
return host_array;
}
{
for (auto& s0 : slice0.slices) {
// loop over outer slice
for (size_t line_id = s0.first; line_id < s0.second; ++line_id) {
// loop over lines
size_t pos0 = line_id * dim1;
for (auto& s1 : slice1.slices) {
// loop over inner slice
size_t pos = pos0 + s1.first;
size_t read_size = (s1.second - s1.first) * sizeof(T);
in.seekg(pos * sizeof(T));
in.read(host_ptr, read_size);
host_ptr += read_size;
}
}
}
in.close();
}
return host_array;
}
}
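
For intuition, a hedged Python sketch (illustrative, not part of the patch) of the same seek-and-read pattern: only the selected rows and columns of a row-major binary tensor are read, so a rank never has to materialize the full weight in memory:

```python
import numpy as np


def read_sliced(filename, dim0, dim1, row_slices, col_slices, dtype=np.float16):
    """Read the [row_slices] x [col_slices] block of a row-major (dim0, dim1) tensor."""
    assert all(0 <= r0 <= r1 <= dim0 for r0, r1 in row_slices)
    assert all(0 <= c0 <= c1 <= dim1 for c0, c1 in col_slices)
    itemsize = np.dtype(dtype).itemsize
    chunks = []
    with open(filename, 'rb') as f:
        for r0, r1 in row_slices:              # outer slices (rows)
            for row in range(r0, r1):
                for c0, c1 in col_slices:      # inner slices (columns)
                    f.seek((row * dim1 + c0) * itemsize)
                    chunks.append(np.frombuffer(f.read((c1 - c0) * itemsize),
                                                dtype=dtype))
    return np.concatenate(chunks)
```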
std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename)
std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename, std::vector<ConcateSlice> slices)
{
return loadWeightFromBinHelper<float>(shape, filename);
return loadWeightFromBinHelper<float>(shape, filename, slices);
}
template<typename T, typename T_IN>
int loadWeightFromBinFunc(T* ptr, std::vector<size_t> shape, std::string filename)
int loadWeightFromBinFunc(T* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices = std::vector<ConcateSlice>())
{
std::vector<T_IN> host_array = loadWeightFromBinHelper<T_IN>(shape, filename);
std::vector<T_IN> host_array = loadWeightFromBinHelper<T_IN>(shape, filename, slices);
if (host_array.empty()) {
return 0;
......@@ -371,49 +471,84 @@ int loadWeightFromBinFunc(T* ptr, std::vector<size_t> shape, std::string filenam
return 0;
}
template int loadWeightFromBinFunc<float, float>(float* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<half, float>(half* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<float, half>(float* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<half, half>(half* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<int8_t, int8_t>(int8_t* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<float, float>(float* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
template int loadWeightFromBinFunc<half, float>(half* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
template int loadWeightFromBinFunc<float, half>(float* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
template int loadWeightFromBinFunc<half, half>(half* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
template int loadWeightFromBinFunc<int8_t, int8_t>(int8_t* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
#ifdef ENABLE_BF16
template int
loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
template int
loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<float, __nv_bfloat16>(float* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<half, __nv_bfloat16>(half* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
template int loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
template int loadWeightFromBinFunc<float, __nv_bfloat16>(float* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
template int loadWeightFromBinFunc<half, __nv_bfloat16>(half* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr,
std::vector<size_t> shape,
std::string filename);
std::string filename,
std::vector<ConcateSlice> slices);
#endif // ENABLE_BF16
template int loadWeightFromBinFunc<int, int>(int* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<int, int>(int* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
#ifdef ENABLE_FP8
template int
loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename);
template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr,
std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices);
#endif // ENABLE_FP8
template<typename T>
int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type)
int loadWeightFromBin(T* ptr,
std::vector<size_t> shape,
std::string filename,
FtCudaDataType model_file_type,
std::vector<ConcateSlice> slices)
{
switch (model_file_type) {
case FtCudaDataType::FP32:
loadWeightFromBinFunc<T, float>(ptr, shape, filename);
loadWeightFromBinFunc<T, float>(ptr, shape, filename, slices);
break;
case FtCudaDataType::FP16:
loadWeightFromBinFunc<T, half>(ptr, shape, filename);
loadWeightFromBinFunc<T, half>(ptr, shape, filename, slices);
break;
case FtCudaDataType::INT8:
loadWeightFromBinFunc<T, int8_t>(ptr, shape, filename);
loadWeightFromBinFunc<T, int8_t>(ptr, shape, filename, slices);
break;
#ifdef ENABLE_BF16
case FtCudaDataType::BF16:
loadWeightFromBinFunc<T, __nv_bfloat16>(ptr, shape, filename);
loadWeightFromBinFunc<T, __nv_bfloat16>(ptr, shape, filename, slices);
break;
#endif
#ifdef ENABLE_FP8
case FtCudaDataType::FP8:
loadWeightFromBinFunc<T, float>(ptr, shape, filename);
loadWeightFromBinFunc<T, float>(ptr, shape, filename, slices);
break;
#endif
default:
......@@ -424,28 +559,50 @@ int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename, F
}
template<>
int loadWeightFromBin(int* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type)
int loadWeightFromBin(int* ptr,
std::vector<size_t> shape,
std::string filename,
FtCudaDataType model_file_type,
std::vector<ConcateSlice> slices)
{
loadWeightFromBinFunc<int, int>(ptr, shape, filename);
loadWeightFromBinFunc<int, int>(ptr, shape, filename, slices);
return 0;
}
template int
loadWeightFromBin(float* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
template int
loadWeightFromBin(half* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
template int
loadWeightFromBin(int8_t* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
template int loadWeightFromBin(float* ptr,
std::vector<size_t> shape,
std::string filename,
FtCudaDataType model_file_type,
std::vector<ConcateSlice> slices);
template int loadWeightFromBin(half* ptr,
std::vector<size_t> shape,
std::string filename,
FtCudaDataType model_file_type,
std::vector<ConcateSlice> slices);
template int loadWeightFromBin(int8_t* ptr,
std::vector<size_t> shape,
std::string filename,
FtCudaDataType model_file_type,
std::vector<ConcateSlice> slices);
#ifdef ENABLE_BF16
template int
loadWeightFromBin(__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
template int loadWeightFromBin(__nv_bfloat16* ptr,
std::vector<size_t> shape,
std::string filename,
FtCudaDataType model_file_type,
std::vector<ConcateSlice> slices);
#endif
#ifdef ENABLE_FP8
template int
loadWeightFromBin(__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
template int loadWeightFromBin(__nv_fp8_e4m3* ptr,
std::vector<size_t> shape,
std::string filename,
FtCudaDataType model_file_type,
std::vector<ConcateSlice> slices);
#endif
template int
loadWeightFromBin(int* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
template int loadWeightFromBin(int* ptr,
std::vector<size_t> shape,
std::string filename,
FtCudaDataType model_file_type,
std::vector<ConcateSlice> slices);
template<typename T_IN, typename T_OUT>
__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size)
......
......@@ -49,13 +49,20 @@ void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream =
template<typename T>
void cudaRandomUniform(T* buffer, const size_t size);
struct ConcateSlice {
std::vector<std::pair<size_t, size_t>> slices;
};
template<typename T>
int loadWeightFromBin(T* ptr,
std::vector<size_t> shape,
std::string filename,
FtCudaDataType model_file_type = FtCudaDataType::FP32);
FtCudaDataType model_file_type = FtCudaDataType::FP32,
std::vector<ConcateSlice> slices = std::vector<ConcateSlice>());
std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename);
std::vector<float> loadArrayFromBin(std::vector<size_t> shape,
std::string filename,
std::vector<ConcateSlice> slices = std::vector<ConcateSlice>());
// template<typename T>
// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,
......