Unverified commit 7cbfe2ea, authored by q.yao and committed by GitHub

Tensor Parallel python api (#82)

* wip

* profile disable tp

* fix profile

* lint

* fix dlpack

* remove comment

* add tp flag

* add session len check

* add eos

* remove tp and session len inputs

* wrap tokenizer

* multithread load weight

* update profile

* refactor tokenizer

* remove pre/post process

* remove mpi4py requirement

* remove

* remove bind

* remove mpi requirement

* check backend_tokenizer
parent 1f88baa5
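The diff below threads tensor parallelism through the Python API: TurboMind now reads tensor_para_size and session_len from the converted workspace's config.ini, creates weights and model instances with one thread per GPU, and a new Tokenizer wrapper replaces the Preprocessor/Postprocessor pair. For orientation, a rough sketch of the resulting user-facing flow follows; the workspace path, model name and prompt are illustrative only, and any stream_infer arguments not shown in this diff are left at their defaults.

from lmdeploy.model import MODELS
from lmdeploy.turbomind import Tokenizer, TurboMind

model_path = './workspace'  # hypothetical output of the deploy/convert step
tokenizer = Tokenizer(f'{model_path}/triton_models/tokenizer')
model = MODELS.get('llama')()  # 'llama' is an illustrative registry key
tm_model = TurboMind(model_path,
                     eos_id=tokenizer.eos_token_id,
                     stop_words=model.stop_words)  # TP degree comes from config.ini
generator = tm_model.create_instance()

prompt = model.get_prompt('hello', True)  # True marks the first round of the session
input_ids = tokenizer.encode(prompt)
for outputs in generator.stream_infer(session_id=1,
                                      input_ids=[input_ids],
                                      request_output_len=512):
    res, tokens = outputs[0]  # token ids and their count, as in the chat script below
    print(tokenizer.decode(res))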
@@ -6,10 +6,9 @@ from threading import Thread
 import fire
 import numpy as np
-from transformers import AutoTokenizer

 from lmdeploy.model import MODELS
-from lmdeploy.turbomind import TurboMind
+from lmdeploy.turbomind import Tokenizer, TurboMind


 def infer(model, session_id: int, input_ids: str, output_seqlen: int,
@@ -42,11 +41,7 @@ def infer(model, session_id: int, input_ids: str, output_seqlen: int,
         que.put((session_id, stats))


-def warmup(model,
-           concurrency: int,
-           session_len: int,
-           output_seqlen: int,
-           warmup_round: int = 4):
+def warmup(model, concurrency: int, output_seqlen: int, warmup_round: int = 4):
     print('start to warmup ...')

     def _infer(model, session_id):
@@ -81,18 +76,16 @@ def warmup(model,
 def main(model_path: str,
          model_name: str,
          concurrency: int = 1,
-         session_len: int = 2056,
          input_seqlen: int = 0,
          output_seqlen: int = 512,
          test_round: int = 10):
     tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_path,
-                                              trust_remote_code=True)
+    tokenizer = Tokenizer(tokenizer_model_path)
     model = MODELS.get(model_name)()
     stop_words = model.stop_words
     tm_model = TurboMind(model_path=model_path, stop_words=stop_words)
-    warmup(tm_model, concurrency, session_len, output_seqlen)
+    warmup(tm_model, concurrency, output_seqlen)

     # make up a prompt that can be tokenized into {input_seqlen} tokens
     prompt = '' if input_seqlen == 0 else 'hi' + ' hi' * (input_seqlen - 1)
...
 # Copyright (c) OpenMMLab. All rights reserved.
-from .tokenizer import Postprocessor, Preprocessor, Tokenizer
+from .tokenizer import Tokenizer
 from .turbomind import TurboMind

-__all__ = ['Postprocessor', 'Preprocessor', 'Tokenizer', 'TurboMind']
+__all__ = ['Tokenizer', 'TurboMind']
@@ -4,10 +4,10 @@ import os.path as osp
 import random

 import fire
-from transformers import AutoTokenizer

 from lmdeploy import turbomind as tm
 from lmdeploy.model import MODELS
+from lmdeploy.turbomind.tokenizer import Tokenizer

 os.environ['TM_LOG_LEVEL'] = 'ERROR'

@@ -39,12 +39,12 @@ def main(model_name, model_path, session_id: int = 1):
         session_id (int): the identical id of a session
     """
     model = MODELS.get(model_name)()
-    tm_model = tm.TurboMind(model_path, stop_words=model.stop_words)
-    generator = tm_model.create_instance()
     tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
-    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_path,
-                                              trust_remote_code=True)
-    model = MODELS.get(model_name)()
+    tokenizer = Tokenizer(tokenizer_model_path)
+    tm_model = tm.TurboMind(model_path,
+                            eos_id=tokenizer.eos_token_id,
+                            stop_words=model.stop_words)
+    generator = tm_model.create_instance()

     nth_round = 1
     step = 0
@@ -56,7 +56,7 @@ def main(model_name, model_path, session_id: int = 1):
             exit(0)
         elif prompt == 'end':
             prompt = model.get_prompt('', nth_round == 1)
-            input_ids = tokenizer.encode(prompt, add_special_tokens=False)
+            input_ids = tokenizer.encode(prompt)
            for outputs in generator.stream_infer(session_id=session_id,
                                                  input_ids=[input_ids],
                                                  request_output_len=512,
@@ -67,10 +67,14 @@ def main(model_name, model_path, session_id: int = 1):
             step = 0
             seed = random.getrandbits(64)
         else:
-            prompt = model.get_prompt(prompt, nth_round == 1)
-            input_ids = tokenizer.encode(prompt, add_special_tokens=False)
             print(f'session {session_id}')
-            print(f'{prompt}', end='', flush=True)
+            if step >= tm_model.session_len:
+                print('WARNING: exceed session max length.'
+                      ' Please end the session.')
+                continue
+            prompt = model.get_prompt(prompt, nth_round == 1)
+            input_ids = tokenizer.encode(prompt)
+            print(f'{prompt} ', end='', flush=True)
             response_size = 0
             for outputs in generator.stream_infer(
                     session_id=session_id,
@@ -89,8 +93,7 @@ def main(model_name, model_path, session_id: int = 1):
                     random_seed=seed if nth_round == 1 else None):
                 res, tokens = outputs[0]
                 # decode res
-                response = tokenizer.decode(
-                    res, skip_special_tokens=True)[response_size:]
+                response = tokenizer.decode(res)[response_size:]
                 response = valid_str(response)
                 print(f'{response}', end='', flush=True)
                 response_size += len(response)
...
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
-from typing import Sequence, Union
+from typing import Sequence

 import torch
-from torch.nn.utils.rnn import pad_sequence


-class Tokenizer:
-    """Tokenize prompts or de-tokenize tokens into texts.
+class SentencePieceTokenizer:
+    """Tokenizer of sentencepiece.

     Args:
         model_file (str): the path of the tokenizer model
     """

     def __init__(self, model_file: str):
-        if model_file.endswith('.model'):
-            model_folder = osp.split(model_file)[0]
-        else:
-            model_folder = model_file
-        tokenizer_config_file = osp.join(model_folder, 'tokenizer_config.json')
-        model_file_exists = osp.exists(model_file)
-        config_exists = osp.exists(tokenizer_config_file)
-        use_hf_model = not config_exists or not model_file_exists
-
-        self.use_hf_model = use_hf_model
-        if not self.use_hf_model:
-            from sentencepiece import SentencePieceProcessor
-            self.model = SentencePieceProcessor(model_file=model_file)
-            self.vocab_size = self.model.vocab_size()
-            self.bos_token_id = self.model.bos_id()
-            self.eos_token_id = self.model.eos_id()
-        else:
-            from transformers import AutoTokenizer
-            backend_tokenizer_file = osp.join(model_folder, 'tokenizer.json')
-            if not osp.exists(backend_tokenizer_file) and model_file_exists:
-                print('WARNING: Can not find tokenizer.json. '
-                      'It may take long time to initialize the tokenizer.')
-            self.model = AutoTokenizer.from_pretrained(model_folder)
-            self.vocab_size = self.model.vocab_size
-            self.bos_token_id = self.model.bos_token_id
-            self.eos_token_id = self.model.eos_token_id
-            # save tokenizer.json to reuse
-            if not osp.exists(backend_tokenizer_file) and model_file_exists:
-                self.model.backend_tokenizer.save(backend_tokenizer_file)
+        from sentencepiece import SentencePieceProcessor
+        self.model = SentencePieceProcessor(model_file=model_file)
+
+    @property
+    def vocab_size(self):
+        """vocabulary size."""
+        return self.model.vocab_size()
+
+    @property
+    def bos_token_id(self):
+        """begin of the sentence token id."""
+        return self.model.bos_id()
+
+    @property
+    def eos_token_id(self):
+        """end of the sentence token id."""
+        return self.model.eos_id()

     def encode(self, s: str):
         """Tokenize a prompt.
@@ -53,25 +39,15 @@ class Tokenizer:
         Returns:
             list[int]: token ids
         """
-        if not self.use_hf_model:
-            add_bos = False
-            add_eos = False
-            if s.find('<BOS>') != -1:
-                s = s.replace('<BOS>', '')
-                add_bos = True
-            if s == '<EOS>':
-                s = ''
-                add_eos = True
-            return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)
-        else:
-            add_special_tokens = False
-            if s.find('<BOS>') != -1:
-                s = s.replace('<BOS>', '<s>')
-            if s == '<EOS>':
-                s = '</s>'
-            if len(s) == 0:
-                add_special_tokens = True
-            return self.model.encode(s, add_special_tokens=add_special_tokens)
+        add_bos = False
+        add_eos = False
+        if s.find('<BOS>') != -1:
+            s = s.replace('<BOS>', '')
+            add_bos = True
+        if s == '<EOS>':
+            s = ''
+            add_eos = True
+        return self.model.Encode(s, add_bos=add_bos, add_eos=add_eos)

     def decode(self, t: Sequence[int]):
         """De-tokenize.
@@ -81,85 +57,132 @@ class Tokenizer:
         Returns:
             str: text of decoding tokens
         """
-        if not self.use_hf_model:
-            return self.model.Decode(t)
-        else:
-            skip_special_tokens = False
-            return self.model.decode(t,
-                                     skip_special_tokens=skip_special_tokens)
-
-
-class Preprocessor:
-    """Tokenize prompts.
-
-    Args:
-        tokenizer (Tokenizer): an instance of tokenizer
-    """
-
-    def __init__(self, tokenizer: Tokenizer):
-        self.tokenizer = tokenizer
-        self.bos_token_id = tokenizer.bos_token_id
-        self.eos_token_id = tokenizer.eos_token_id
-
-    def __call__(self, *args, **kwargs):
-        return self.infer(*args, **kwargs)
-
-    def infer(self, prompts: Union[str, Sequence[str]]) -> tuple:
-        """Tokenize the input prompts.
-
-        Args:
-            prompts(str | Sequence[str]): user's prompt, or a batch prompts
-
-        Returns:
-            Tuple(torch.Tensor, torch.Tensor): prompt's token
-            ids, ids' length and requested output length
-        """
-        if isinstance(prompts, str):
-            _ = [[prompts]]
-        elif isinstance(prompts, Sequence):
-            _ = [[prompt] for prompt in prompts]
-        else:
-            assert 0, f'str or Sequence[str] prompts are expected but got ' \
-                      f'{type(prompts)}'
-
-        start_ids = [
-            torch.IntTensor(self.tokenizer.encode(prompt))
-            for prompt in prompts
-        ]
-        start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])
-        start_ids = pad_sequence(start_ids,
-                                 batch_first=True,
-                                 padding_value=self.eos_token_id)
-        return start_ids, start_lengths
-
-
-class Postprocessor:
-    """De-tokenize token ids.
-
-    Args:
-        tokenizer (Tokenizer): an instance of tokenizer
-    """
-
-    def __init__(self, tokenizer: Tokenizer):
-        self.tokenizer = tokenizer
-        self.bos_token_id = tokenizer.bos_token_id
-        self.eos_token_id = tokenizer.eos_token_id
-
-    def __call__(self, *args, **kwargs):
-        return self.infer(*args, **kwargs)
-
-    def infer(self, output_ids: torch.Tensor, seqlen: torch.Tensor):
-        """De-tokenize tokens for text.
-
-        Args:
-            output_ids(torch.Tensor): tokens' id
-            seqlen(torch.Tensor): sequence length
-
-        Returns:
-            str: decoded tokens
-        """
-        outputs = []
-        for tokens, _len in zip(output_ids, seqlen):
-            output = self.tokenizer.decode(tokens[:_len])
-            outputs.append(output)
-        return outputs
+        if isinstance(t, torch.Tensor):
+            t = t.tolist()
+        return self.model.Decode(t)
+
+
+class HuggingFaceTokenizer:
+    """Tokenizer of huggingface.
+
+    Args:
+        model_dir (str): the directory of the tokenizer model
+    """
+
+    def __init__(self, model_dir: str):
+        from transformers import AutoTokenizer
+        model_file = osp.join(model_dir, 'tokenizer.model')
+        backend_tokenizer_file = osp.join(model_dir, 'tokenizer.json')
+        model_file_exists = osp.exists(model_file)
+        if not osp.exists(backend_tokenizer_file) and model_file_exists:
+            print('WARNING: Can not find tokenizer.json. '
+                  'It may take long time to initialize the tokenizer.')
+        self.model = AutoTokenizer.from_pretrained(model_dir,
+                                                   trust_remote_code=True)
+        # save tokenizer.json to reuse
+        if not osp.exists(backend_tokenizer_file) and model_file_exists:
+            if hasattr(self.model, 'backend_tokenizer'):
+                self.model.backend_tokenizer.save(backend_tokenizer_file)
+
+    @property
+    def vocab_size(self):
+        """vocabulary size."""
+        return self.model.vocab_size
+
+    @property
+    def bos_token_id(self):
+        """begin of the sentence token id."""
+        return self.model.bos_token_id
+
+    @property
+    def eos_token_id(self):
+        """end of the sentence token id."""
+        return self.model.eos_token_id
+
+    def encode(self, s: str):
+        """Tokenize a prompt.
+
+        Args:
+            s (str): a prompt
+        Returns:
+            list[int]: token ids
+        """
+        add_special_tokens = False
+        if s.find('<BOS>') != -1:
+            s = s.replace('<BOS>', '<s>')
+        if s == '<EOS>':
+            s = '</s>'
+        if len(s) == 0:
+            add_special_tokens = True
+        return self.model.encode(s, add_special_tokens=add_special_tokens)
+
+    def decode(self, t: Sequence[int]):
+        """De-tokenize.
+
+        Args:
+            t (List[int]): a list of token ids
+        Returns:
+            str: text of decoding tokens
+        """
+        skip_special_tokens = True
+        return self.model.decode(t, skip_special_tokens=skip_special_tokens)
+
+
+class Tokenizer:
+    """Tokenize prompts or de-tokenize tokens into texts.
+
+    Args:
+        model_file (str): the path of the tokenizer model
+    """
+
+    def __init__(self, model_file: str):
+        if model_file.endswith('.model'):
+            model_folder = osp.split(model_file)[0]
+        else:
+            model_folder = model_file
+            model_file = osp.join(model_folder, 'tokenizer.model')
+        tokenizer_config_file = osp.join(model_folder, 'tokenizer_config.json')
+
+        model_file_exists = osp.exists(model_file)
+        config_exists = osp.exists(tokenizer_config_file)
+        use_hf_model = config_exists or not model_file_exists
+
+        if not use_hf_model:
+            self.model = SentencePieceTokenizer(model_file)
+        else:
+            self.model = HuggingFaceTokenizer(model_folder)
+
+    @property
+    def vocab_size(self):
+        """vocabulary size."""
+        return self.model.vocab_size
+
+    @property
+    def bos_token_id(self):
+        """begin of the sentence token id."""
+        return self.model.bos_token_id
+
+    @property
+    def eos_token_id(self):
+        """end of the sentence token id."""
+        return self.model.eos_token_id
+
+    def encode(self, s: str):
+        """Tokenize a prompt.
+
+        Args:
+            s (str): a prompt
+        Returns:
+            list[int]: token ids
+        """
+        return self.model.encode(s)
+
+    def decode(self, t: Sequence[int]):
+        """De-tokenize.
+
+        Args:
+            t (List[int]): a list of token ids
+        Returns:
+            str: text of decoding tokens
+        """
+        return self.model.decode(t)
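The SentencePiece and Hugging Face backends above share the same <BOS>/<EOS> convention, which the Tokenizer facade forwards unchanged. A minimal sketch of that behaviour, assuming a converted workspace at an illustrative path:

from lmdeploy.turbomind import Tokenizer

tok = Tokenizer('./workspace/triton_models/tokenizer')  # hypothetical path
with_bos = tok.encode('<BOS>hello')  # '<BOS>' prefix requests a BOS token before the text
plain = tok.encode('hello')          # ordinary text gets no special tokens
eos_only = tok.encode('<EOS>')       # the bare string '<EOS>' maps to a single EOS token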
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
 import sys
+from configparser import ConfigParser
+from contextlib import contextmanager
 from queue import Queue
 from threading import Thread
 from typing import Iterable, List
@@ -53,47 +55,80 @@ def _tm_dict_to_torch_dict(tm_dict: _tm.TensorMap):
     return ret


+@contextmanager
+def cuda_ctx(device_id):
+    old_device = torch.cuda.current_device()
+    torch.cuda.set_device(device_id)
+    yield
+    torch.cuda.set_device(old_device)
+
+
 class TurboMind:
     """LMDeploy's inference engine.

     Args:
         model_path (str): the path of turbomind's model
         data_type (str): the data type
-        session_len (int): the max length of a session
         eos_id (int): eos token id
         stop_words (List[int]): token ids of stop-words
-        device_id (int): the id of a gpu card
-        node_id (int): the id of a node
-        device_num (int): the number of gpu cards
-        node_num (int): the number of node
     """

     def __init__(self,
                  model_path: str,
                  data_type: str = 'fp16',
-                 session_len: int = 2048,
                  eos_id: int = 2,
-                 stop_words: List[int] = None,
-                 device_id: int = 0,
-                 node_id: int = 0,
-                 device_num: int = 1,
-                 node_num: int = 1):
+                 stop_words: List[int] = None):
         self.eos_id = eos_id

-        # create model instance
+        # TODO: support mpi
+        node_id = 0
+        node_num = 1
+
+        # read meta from model path
+        self.gpu_count = 1
+        self.session_len = 2048
+        ini_path = osp.join(model_path, 'triton_models/weights/config.ini')
+        with open(ini_path, 'r') as f:
+            parser = ConfigParser()
+            parser.read_file(f)
+            section_name = ''
+            if 'turbomind' in parser:
+                section_name = 'turbomind'
+            elif 'llama' in parser:
+                section_name = 'llama'
+
+            if len(section_name) > 0:
+                self.gpu_count = parser.getint(section_name,
+                                               'tensor_para_size')
+                self.session_len = parser.getint(section_name, 'session_len')
+
+        # params
         self.node_id = node_id
         self.node_num = node_num
-        self.gpu_count = device_num
-        self.device_id = device_id
         self.world_size = self.node_num * self.gpu_count
-        self.rank = self.node_id * self.gpu_count + self.device_id
-        self.session_len = session_len

+        # create model
         weight_dir = osp.join(model_path, 'triton_models', 'weights')
         model = _tm.AbstractTransformerModel.create_llama_model(
             weight_dir, tensor_para_size=self.gpu_count, data_type=data_type)
-        model.create_shared_weights(self.device_id, self.rank)
         self.model = model
+        self.nccl_params = model.create_nccl_params(self.node_id)
+        torch.cuda.synchronize()
+
+        # create weight
+        def _create_weight(device_id):
+            with cuda_ctx(device_id):
+                rank = self.node_id * self.gpu_count + device_id
+                model.create_shared_weights(device_id, rank)
+
+        threads = []
+        for device_id in range(self.gpu_count):
+            t = Thread(target=_create_weight, args=(device_id, ))
+            t.start()
+            threads.append(t)
+        for t in threads:
+            t.join()

         self.stop_words = _stop_words(stop_words)

     def create_instance(self, cuda_stream_id=0):
@@ -117,40 +152,57 @@ class TurboMindInstance:

     def __init__(self, tm_model, cuda_stream_id=0):
         self.tm_model = tm_model
-        self.device_id = tm_model.device_id
-        self.rank = tm_model.rank
+        self.cuda_stream_id = cuda_stream_id
+
+        self.node_id = tm_model.node_id
+        self.gpu_count = tm_model.gpu_count
+
         self.stop_words = tm_model.stop_words
         self.eos_id = tm_model.eos_id
         self.session_len = tm_model.session_len
-        self.cuda_stream_id = cuda_stream_id

-        # create instance
-        model = tm_model.model
-        nccl_params = model.create_nccl_params(tm_model.node_id)
-        custom_comms = model.create_custom_comms(tm_model.world_size)
-        instance_comm = model.create_instance_comm(tm_model.gpu_count)
+        self.nccl_params = tm_model.nccl_params
+        self.instance_comm = tm_model.model.create_instance_comm(
+            self.gpu_count)

-        model_inst = model.create_model_instance(self.device_id, self.rank,
-                                                 self.cuda_stream_id,
-                                                 nccl_params, custom_comms[0])
-        # model_inst.register_callback(self._forward_callback)
-        self.model_inst = model_inst
-        self.instance_comm = instance_comm
+        # create model instances
+        model_insts = [None] * self.gpu_count
+        threads = []
+        for device_id in range(self.gpu_count):
+            t = Thread(target=self._create_model_instance,
+                       args=(device_id, model_insts))
+            t.start()
+            threads.append(t)
+        for t in threads:
+            t.join()
+        self.model_insts = model_insts

         self.que = Queue()
-        self.thread = None
+        self.threads = [None] * self.gpu_count
+
+    def _create_model_instance(self, device_id, model_insts):
+        with cuda_ctx(device_id):
+            rank = self.node_id * self.gpu_count + device_id
+            model_inst = self.tm_model.model.create_model_instance(
+                device_id, rank, self.cuda_stream_id, self.nccl_params)
+            model_insts[device_id] = model_inst

     def _forward_callback(self, result, ctx):
         self.que.put((False, result))

     def _forward_thread(self, inputs):

-        def _func():
-            output = self.model_inst.forward(inputs, self.instance_comm)
-            self.que.put((True, output))
+        def _func(device_id, enque_output):
+            with cuda_ctx(device_id):
+                output = self.model_insts[device_id].forward(
+                    inputs, self.instance_comm)
+                if enque_output:
+                    self.que.put((True, output))

-        self.thread = Thread(target=_func)
-        self.thread.start()
+        for device_id in range(self.gpu_count):
+            t = Thread(target=_func, args=(device_id, device_id == 0))
+            t.start()
+            self.threads[device_id] = t

     def stream_infer(self,
                      session_id,
@@ -190,7 +242,7 @@ class TurboMindInstance:
             stream_output (bool): indicator for stream output
         """
         if stream_output:
-            self.model_inst.register_callback(self._forward_callback)
+            self.model_insts[0].register_callback(self._forward_callback)

         if len(input_ids) == 0:
             input_ids = []
@@ -281,10 +333,11 @@ class TurboMindInstance:
                        for output, l in zip(output_ids, sequence_length)]

             if finish:
+                for t in self.threads:
+                    t.join()
                 while self.que.qsize() > 0:
                     self.que.get()
-                self.thread.join()
                 break

         if stream_output:
-            self.model_inst.unregister_callback()
+            self.model_insts[0].unregister_callback()
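Note that the tensor-parallel degree is no longer a constructor argument: TurboMind.__init__ above reads tensor_para_size and session_len from the workspace's triton_models/weights/config.ini, checking for a [turbomind] section first and falling back to [llama]. An illustrative fragment of such a file (values are placeholders; other keys omitted):

[llama]
tensor_para_size = 2
session_len = 2048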
@@ -14,7 +14,7 @@ endif()

 pybind11_add_module(${PROJECT_NAME} bind.cpp)
 target_link_libraries(${PROJECT_NAME} PRIVATE TransformerTritonBackend
-                                              LlamaTritonBackend custom_ar_comm memory_utils)
+                                              LlamaTritonBackend)
 target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_14)
 set_target_properties(${PROJECT_NAME} PROPERTIES
...
#include "src/turbomind/python/dlpack.h" #include "src/turbomind/python/dlpack.h"
#include "src/turbomind/triton_backend/llama/LlamaTritonModel.h" #include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
#include "src/turbomind/triton_backend/transformer_triton_backend.hpp" #include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
#include "src/turbomind/utils/nccl_utils.h"
#include <cuda_runtime.h>
#include <memory> #include <memory>
#include <pybind11/functional.h> #include <pybind11/functional.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
...@@ -26,7 +28,14 @@ std::shared_ptr<T> make_shared_nodel(T data) ...@@ -26,7 +28,14 @@ std::shared_ptr<T> make_shared_nodel(T data)
DLDevice getDLDevice(triton::Tensor& tensor) DLDevice getDLDevice(triton::Tensor& tensor)
{ {
DLDevice device{.device_id = 0}; int device_id = 0;
if (tensor.where == triton::MEMORY_GPU) {
cudaPointerAttributes ptr_attr;
cudaPointerGetAttributes(&ptr_attr, tensor.data);
device_id = ptr_attr.device;
}
DLDevice device{.device_id = device_id};
switch (tensor.where) { switch (tensor.where) {
case triton::MEMORY_CPU: case triton::MEMORY_CPU:
...@@ -204,7 +213,6 @@ std::shared_ptr<triton::Tensor> DLManagedTensorToTritonTensor(DLManagedTensor* t ...@@ -204,7 +213,6 @@ std::shared_ptr<triton::Tensor> DLManagedTensorToTritonTensor(DLManagedTensor* t
PYBIND11_MODULE(_turbomind, m) PYBIND11_MODULE(_turbomind, m)
{ {
// nccl param // nccl param
py::class_<ft::NcclParam>(m, "NcclParam") py::class_<ft::NcclParam>(m, "NcclParam")
.def(py::init<int, int>(), "rank"_a = 0, "world_size"_a = 1) .def(py::init<int, int>(), "rank"_a = 0, "world_size"_a = 1)
...@@ -320,7 +328,6 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -320,7 +328,6 @@ PYBIND11_MODULE(_turbomind, m)
// transformer model // transformer model
py::class_<AbstractTransformerModel, std::shared_ptr<AbstractTransformerModel>>(m, "AbstractTransformerModel") py::class_<AbstractTransformerModel, std::shared_ptr<AbstractTransformerModel>>(m, "AbstractTransformerModel")
// .def_static("create_llama_model", &AbstractTransformerModel::createLlamaModel, "model_dir"_a)
.def_static( .def_static(
"create_llama_model", "create_llama_model",
[](std::string model_dir, [](std::string model_dir,
...@@ -349,7 +356,7 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -349,7 +356,7 @@ PYBIND11_MODULE(_turbomind, m)
"multi_node"_a = false) "multi_node"_a = false)
.def( .def(
"create_custom_comms", "create_custom_comms",
[](std::shared_ptr<AbstractTransformerModel>& model, int world_size) { [](AbstractTransformerModel* model, int world_size) {
std::vector<std::shared_ptr<ft::AbstractCustomComm>> ret; std::vector<std::shared_ptr<ft::AbstractCustomComm>> ret;
model->createCustomComms(&ret, world_size); model->createCustomComms(&ret, world_size);
return ret; return ret;
...@@ -358,7 +365,7 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -358,7 +365,7 @@ PYBIND11_MODULE(_turbomind, m)
.def("create_instance_comm", &AbstractTransformerModel::createInstanceComm, "size"_a) .def("create_instance_comm", &AbstractTransformerModel::createInstanceComm, "size"_a)
.def( .def(
"create_model_instance", "create_model_instance",
[](std::shared_ptr<AbstractTransformerModel>& model, [](AbstractTransformerModel* model,
int deviceId, int deviceId,
int rank, int rank,
long stream_id, long stream_id,
...@@ -367,12 +374,17 @@ PYBIND11_MODULE(_turbomind, m) ...@@ -367,12 +374,17 @@ PYBIND11_MODULE(_turbomind, m)
cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id); cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
return model->createModelInstance(deviceId, rank, stream, nccl_params, custom_all_reduce_comm); return model->createModelInstance(deviceId, rank, stream, nccl_params, custom_all_reduce_comm);
}, },
py::call_guard<py::gil_scoped_release>(),
"device_id"_a, "device_id"_a,
"rank"_a, "rank"_a,
"stream"_a, "stream"_a,
"nccl_params"_a, "nccl_params"_a,
"custom_all_reduce_comm"_a = nullptr) "custom_all_reduce_comm"_a = nullptr)
.def("create_shared_weights", &AbstractTransformerModel::createSharedWeights, "device_id"_a, "rank"_a) .def("create_shared_weights",
&AbstractTransformerModel::createSharedWeights,
py::call_guard<py::gil_scoped_release>(),
"device_id"_a,
"rank"_a)
.def("__str__", &AbstractTransformerModel::toString) .def("__str__", &AbstractTransformerModel::toString)
.def("__repr__", &AbstractTransformerModel::toString) .def("__repr__", &AbstractTransformerModel::toString)
.def("get_tensor_para_size", &AbstractTransformerModel::getTensorParaSize) .def("get_tensor_para_size", &AbstractTransformerModel::getTensorParaSize)
......
@@ -283,7 +283,7 @@ export(PACKAGE TritonTurboMindBackend)
 # limitations under the License.

 add_library(TransformerTritonBackend SHARED transformer_triton_backend.cpp)
-target_link_libraries(TransformerTritonBackend PRIVATE nccl_utils mpi_utils)
+target_link_libraries(TransformerTritonBackend PRIVATE nccl_utils)
 install(TARGETS TransformerTritonBackend DESTINATION ${CMAKE_INSTALL_LIBDIR})

 add_subdirectory(llama)
@@ -39,9 +39,6 @@ AbstractTransformerModel::createNcclParams(const int node_id, const int device_i
                 ft::ftNcclGetUniqueId(nccl_ids[i]);
             }
         }
-        for (size_t i = 0; i < nccl_ids.size(); i++) {
-            ft::mpi::bcast(&nccl_ids[i], sizeof(nccl_ids[i]), ft::mpi::MPI_TYPE_BYTE, 0, ft::mpi::COMM_WORLD);
-        }
     }

     std::vector<ft::NcclParam> tensor_para_params(local_comm_size);
...
@@ -29,7 +29,6 @@
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/custom_ar_comm.h"
 #include "src/turbomind/utils/instance_comm.h"
-#include "src/turbomind/utils/mpi_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"

 namespace ft = turbomind;
...
@@ -64,7 +64,7 @@ add_library(nccl_utils STATIC nccl_utils.cc)
 set_property(TARGET nccl_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET nccl_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
 if (BUILD_MULTI_GPU)
-    target_link_libraries(nccl_utils PUBLIC ${NCCL_LIBRARIES} mpi_utils logger)
+    target_link_libraries(nccl_utils PUBLIC ${NCCL_LIBRARIES} logger)
 endif()

 add_library(cublasINT8MMWrapper STATIC cublasINT8MMWrapper.cc)
...
@@ -306,118 +306,6 @@ void ftNcclParamDestroy(NcclParam& param)
#endif
}
void ftNcclInitialize(NcclParam& tensor_para,
NcclParam& pipeline_para,
const int tensor_para_size,
const int pipeline_para_size)
{
TM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
// Initialize nccl communication grid of tensor and pipeline parallel groups.
#ifndef BUILD_MULTI_GPU
FT_CHECK_WITH_INFO(tensor_para_size == 1,
fmtstr("tensor_para_size=%d although BUILD_MULTI_GPU is disabled. "
"Please use the cmake flag -DBUILD_MULTI_GPU=ON if you want "
"to use tensor/pipeline parallelism.",
tensor_para_size));
FT_CHECK_WITH_INFO(pipeline_para_size == 1,
fmtstr("pipeline_para_size=%d although BUILD_MULTI_GPU is disabled. "
"Please use the cmake flag -DBUILD_MULTI_GPU=ON if you want "
"to use tensor/pipeline parallelism.",
pipeline_para_size));
tensor_para.rank_ = 0;
tensor_para.world_size_ = tensor_para_size;
pipeline_para.rank_ = 0;
pipeline_para.world_size_ = pipeline_para_size;
#else
// Initialize a nccl communicator.
if (tensor_para.nccl_comm_ != nullptr && pipeline_para.nccl_comm_ != nullptr) {
TM_LOG_WARNING("NcclParam is already initialized. Skip NCCL initialization.");
return;
}
FT_CHECK(tensor_para.nccl_comm_ == nullptr);
FT_CHECK(pipeline_para.nccl_comm_ == nullptr);
FT_CHECK(tensor_para_size > 0);
FT_CHECK(pipeline_para_size > 0);
if (tensor_para_size == 1 && pipeline_para_size == 1) {
TM_LOG_WARNING("Skip NCCL initialization since requested tensor/pipeline parallel sizes are equals to 1.");
tensor_para.rank_ = 0;
tensor_para.world_size_ = tensor_para_size;
pipeline_para.rank_ = 0;
pipeline_para.world_size_ = pipeline_para_size;
return;
}
int mpi_initialized;
MPICHECK(MPI_Initialized(&mpi_initialized));
FT_CHECK_WITH_INFO(mpi_initialized, "Fail to nccl initialization because MPI is not initialized.");
int rank, world_size;
MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &world_size));
FT_CHECK_WITH_INFO(tensor_para_size * pipeline_para_size <= world_size,
fmtstr("tensor_para_size (%d) * pipeline_para_size (%d) should equal to the world size (%d).",
tensor_para_size,
pipeline_para_size,
world_size));
// Convert WORLD communicator into 2D grid (k * n) communicator.
// row = a tensor parallel group, col = a pipeline parallel group.
MPI_Comm grid_comm, tp_comm, pp_comm;
int dims[2] = {pipeline_para_size, tensor_para_size};
int periods[2] = {0, 0};
MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 0, &grid_comm);
// Split 2D communicator into rows and cols.
int tp_remain_dims[2] = {false, true};
int pp_remain_dims[2] = {true, false};
MPI_Cart_sub(grid_comm, tp_remain_dims, &tp_comm);
MPI_Cart_sub(grid_comm, pp_remain_dims, &pp_comm);
int tp_rank, pp_rank;
MPI_Comm_rank(tp_comm, &tp_rank);
MPI_Comm_rank(pp_comm, &pp_rank);
ncclUniqueId tp_uid;
ncclUniqueId pp_uid;
// The root of each group creates a nccl uid.
if (tp_rank == 0) {
TM_LOG_DEBUG("rank %d pp rank %d creates nccl uid.", rank, tp_rank);
NCCLCHECK(ncclGetUniqueId(&tp_uid));
}
if (pp_rank == 0) {
TM_LOG_DEBUG("rank %d pp rank %d creates nccl uid.", rank, pp_rank);
NCCLCHECK(ncclGetUniqueId(&pp_uid));
}
// Broadcast nccl uid to share the same nccl uid across gpus in the same group.
TM_LOG_DEBUG("Broadcast nccl uid to the others in the same parallel groups.");
MPI_Bcast(&tp_uid, sizeof(tp_uid), MPI_BYTE, 0, tp_comm);
MPI_Bcast(&pp_uid, sizeof(pp_uid), MPI_BYTE, 0, pp_comm);
TM_LOG_DEBUG("Initialize NCCL communicators.");
ncclComm_t tp_nccl_comm, pp_nccl_comm;
NCCLCHECK(ncclCommInitRank(&tp_nccl_comm, tensor_para_size, tp_uid, tp_rank));
NCCLCHECK(ncclCommInitRank(&pp_nccl_comm, pipeline_para_size, pp_uid, pp_rank));
tensor_para.world_size_ = tensor_para_size;
tensor_para.rank_ = tp_rank;
tensor_para.nccl_uid_ = tp_uid;
tensor_para.nccl_comm_ = tp_nccl_comm;
pipeline_para.world_size_ = pipeline_para_size;
pipeline_para.rank_ = pp_rank;
pipeline_para.nccl_uid_ = pp_uid;
pipeline_para.nccl_comm_ = pp_nccl_comm;
TM_LOG_INFO("NCCL initialized rank=%d world_size=%d tensor_para=%s pipeline_para=%s",
rank,
world_size,
tensor_para.toString().c_str(),
pipeline_para.toString().c_str());
#endif
TM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
static std::atomic<int>& ncclGroupCount()
{
    static std::atomic<int> value{};
...
@@ -18,11 +18,9 @@
 #include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/logger.h"
-#include "src/turbomind/utils/mpi_utils.h"

 #include <cuda_runtime.h>
 #ifdef BUILD_MULTI_GPU
-#include <mpi.h>
 #include <nccl.h>
 #endif
 #include <stdio.h>
@@ -118,11 +116,6 @@ void ftNcclGetUniqueId(NcclUid& uid);
 void ftNcclCommInitRank(NcclParam& param, const int rank, const int world_size, const NcclUid uid);
 void ftNcclParamDestroy(NcclParam& param);

-void ftNcclInitialize(NcclParam& tensor_para,
-                      NcclParam& pipeline_para,
-                      const int tensor_para_size,
-                      const int pipeline_para_size);
-
 int ftNcclNextGroupId();
 int ftNcclGroupCount();
...
-from lmdeploy.turbomind.tokenizer import Postprocessor, Preprocessor, Tokenizer
+from lmdeploy.turbomind.tokenizer import Tokenizer


 def main():
     tokenizer = Tokenizer('huggyllama/llama-7b')
-    preprocessor = Preprocessor(tokenizer)
-    postprocessor = Postprocessor(tokenizer)

     prompts = ['cest la vie', '上帝已死']
-    tokens = preprocessor(prompts)
-    print(tokens)
-
-    decode_prompts = postprocessor(*tokens)
-    print(decode_prompts)
+    for prompt in prompts:
+        tokens = tokenizer.encode(prompt)
+        output = tokenizer.decode(tokens)
+        print(output)


 if __name__ == '__main__':
...