Unverified commit 62e0fa9a authored by lvhan028, committed by GitHub

support 'input_tokens' in triton_example (#49)

* check in a script for tokenizing a file

* use max_input_len
parent 4c303b17
@@ -5,3 +5,4 @@ __pycache__/
 workspace/
 .cache
 *build*/
+examples/cpp/llama/*.csv
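(The new ignore pattern presumably covers the start_ids.csv that the tokenizer script further down in this diff writes next to itself, keeping generated token-id files out of version control, assuming the script lives under examples/cpp/llama/.)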
@@ -8,6 +8,7 @@ model_dir=/workspace/models/triton_models/weights/
 [request]
 request_batch_size=8
+max_input_len=1
 request_output_len=2048
 beam_width=1 ; beam width for beam search
 top_k=1 ; k value for top k sampling
......
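With this change, max_input_len becomes a request parameter: read_start_ids stops consuming token ids from a CSV line once max_input_len ids have been read, so the value of 1 here keeps only the first token of each prompt.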
@@ -244,7 +244,7 @@ broadCastRequest(const std::vector<int>& v_start_ids,
 int read_start_ids(size_t batch_size,
                    std::vector<int>* v_start_lengths,
                    std::vector<int>* v_start_ids,
-                   size_t& max_input_len,
+                   size_t max_input_len,
                    const int end_id,
                    const int beam_width,
                    std::string file_name);
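Note the signature change: max_input_len turns from an out-parameter (size_t&, previously computed inside read_start_ids as the longest line in the file) into a plain input (size_t) that the caller now supplies from the config.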
@@ -263,11 +263,11 @@ prepareRequest(std::string ini_name, const int node_id, const int gpu_count, std
     const int start_id = reader.GetInteger("request", "start_id");
     const int end_id = reader.GetInteger("request", "end_id");
+    const int max_input_len = reader.GetInteger("request", "max_input_len");
 
     std::vector<int> v_start_ids;
     std::vector<int> v_start_lengths;
-    size_t max_input_len = 0;
     read_start_ids(request_batch_size,
                    &v_start_lengths,
                    &v_start_ids,
@@ -427,6 +427,7 @@ int main(int argc, char* argv[])
     const int batch_size = output_tensors_lists[0].get()->at("output_ids").shape[0];
     const int beam_width = output_tensors_lists[0].get()->at("output_ids").shape[1];
     const int seq_len = output_tensors_lists[0].get()->at("output_ids").shape[2];
+    std::vector<int> seq_lens(batch_size);
     // step 6: check results
     if (node_id == 0) {
         std::string fName = "out";
@@ -439,7 +440,6 @@ int main(int argc, char* argv[])
         // int* hBuf = new int[outCount];
         std::vector<int> hBuf(outCount);
         ft::cudaD2Hcpy(hBuf.data(), d_output_ids, outCount);
-        std::vector<int> seq_lens(batch_size);
         ft::cudaD2Hcpy(seq_lens.data(), d_seq_lens, batch_size);
         std::cout << "sequence length: ";
         for (int i = 0; i < batch_size; ++i) {
@@ -503,7 +503,7 @@ int main(int argc, char* argv[])
                " FT-CPP-GPT-Triton-time %.2f ms\n",
                batch_size,
                beam_width,
-               seq_len,
+               seq_lens[0],
                ((end.tv_sec - start.tv_sec) * 1000 + (end.tv_usec - start.tv_usec) * 0.001) / ite);
     }
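seq_lens is hoisted out of the node_id == 0 block to function scope (the two hunks above) so that this timing printf can report the actual sequence length of the first request, seq_lens[0], copied back from the device, rather than the padded seq_len dimension of the output tensor.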
@@ -516,7 +516,7 @@ int main(int argc, char* argv[])
 int read_start_ids(size_t batch_size,
                    std::vector<int>* v_start_lengths,
                    std::vector<int>* v_start_ids,
-                   size_t& max_input_len,
+                   size_t max_input_len,
                    const int end_id,
                    const int beam_width,
                    std::string file_name)
@@ -531,14 +531,14 @@ int read_start_ids(size_t batch_size,
     while (std::getline(start_id_file, line)) {
         std::stringstream lineStream(line);
         std::string vals;
-        int i1 = 0;
         std::vector<int> tmp_vec;
         while (std::getline(lineStream, vals, ',')) {
             tmp_vec.push_back(std::stoi(vals));
-            i1++;
+            if (tmp_vec.size() == max_input_len)
+                break;
         }
         tmp_start_ids.push_back(tmp_vec);
-        tmp_start_lengths.push_back(i1);
+        tmp_start_lengths.push_back(tmp_vec.size());
         line_num++;
     }
     if (batch_size == 0) {
@@ -551,19 +551,6 @@ int read_start_ids(size_t batch_size,
         return 0;
     }
-    max_input_len = tmp_start_lengths.data()[0];
-    for (uint i = 1; i < (uint)tmp_start_lengths.size(); i++) {
-        max_input_len = max_input_len > tmp_start_lengths.data()[i] ? max_input_len : tmp_start_lengths.data()[i];
-    }
-    while ((int)tmp_start_lengths.size() < batch_size) {
-        std::vector<int> padding_ids;
-        for (int i = 0; i < max_input_len; i++) {
-            padding_ids.push_back(end_id);
-        }
-        tmp_start_ids.push_back(padding_ids);
-        tmp_start_lengths.push_back(max_input_len);
-    }
 
     // Add padding
     for (int i = 0; i < (int)tmp_start_ids.size(); i++) {
@@ -572,6 +559,12 @@ int read_start_ids(size_t batch_size,
         }
     }
+    // Pad to batch_size
+    for (int i = (int)tmp_start_lengths.size(); i < batch_size; i++) {
+        tmp_start_ids.push_back(tmp_start_ids[0]);
+        tmp_start_lengths.push_back(tmp_start_lengths[0]);
+    }
+
     for (int i = 0; i < (int)tmp_start_ids.size(); i++) {
         for (int b = 0; b < beam_width; b++) {
             for (int j = 0; j < (int)tmp_start_ids[i].size(); j++) {
......
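Taken together, the hunks above change read_start_ids from "measure the data, pad with end_id" to "truncate each prompt to the configured max_input_len, then pad a short batch by repeating the first prompt". A minimal Python sketch of the new logic, for orientation only (the committed implementation is the C++ above; the function name and file handling here are illustrative):

```python
# Sketch of the new read_start_ids behavior (illustrative; the committed
# implementation is the C++ in the diff above).
def read_start_ids(csv_path: str, batch_size: int, max_input_len: int):
    start_ids, start_lengths = [], []
    with open(csv_path) as f:
        for line in f.read().splitlines():
            # keep at most max_input_len token ids per prompt (new truncation)
            ids = [int(t) for t in line.split(',')][:max_input_len]
            start_ids.append(ids)
            start_lengths.append(len(ids))
    # new padding: fill a short batch by repeating the first prompt,
    # instead of appending rows of end_id tokens as before
    while len(start_ids) < batch_size:
        start_ids.append(start_ids[0])
        start_lengths.append(start_lengths[0])
    return start_ids, start_lengths
```

The eight comma-separated rows that follow are the sample-prompt CSV touched by this commit: each row is one prompt's token ids, and the rows share a long common prefix, differing only near the end.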
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,44883,2282,32901,4220,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,46088,46064,625,19880,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,47335,56437,60468,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,44883,2282,6828,3467,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,36589,3467,7849,299,7032,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,44976,39798,6828,3467,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,2795,977,9193,299,405,537,46323,13,44975,45004,11130,32843,45004,35597
0,18396,22305,13,4662,561,399,326,44875,29913,6938,1198,345,3134,39407,320,47997,45778,45121,61969,47371,492,13,44872,65616,47997,45778,45121,61969,47371,345,263,13820,1558,5515,2404,409,345,12643,521,41109,34993,326,44875,24488,10677,320,45691,45926,45513,46641,47641,46285,6456,492,824,345,12314,307,377,11951,44863,23391,44863,329,5420,935,421,44858,13,44872,65616,47997,45778,45121,61969,47371,541,2914,329,34352,30302,3530,299,278,5515,14966,521,278,1711,1591,425,5716,329,65616,45452,45545,44858,13,570,996,372,13,44975,45004,44950,11111,45004,35597,45691,45926,45513,46641,47641,46285,6456,46323,13,44975,45004,11130,32843,45004,35597
+import os.path as osp
 from typing import List
 
 import fire
@@ -21,11 +22,6 @@ class Tokenizer:
         self.end_id = self.model.eos_token_id
         self.pad_id = self.model.pad_token_id
 
-        print(f'vocab_size = {self.vocab_size}')
-        print(f'start_id = {self.start_id}')
-        print(f'end_id = {self.end_id}')
-        print(f'pad_id = {self.pad_id}')
-
     def encode(self, s: str):
         if hasattr(self.model, 'Encode'):
             return self.model.Encode(s, add_bos=True)
@@ -46,10 +42,21 @@ def main(model_file: str = '/data/llama/model/tokenizer.model',
     if encode_file:
         with open(encode_file, 'r') as f:
             xs = tokenizer.encode(f.read())
-            print(','.join(map(str, xs)))
+            xs = ','.join(map(str, xs))
+            print(xs)
+            output_dir = osp.dirname(osp.abspath(__file__))
+            with open(osp.join(output_dir, 'start_ids.csv'), 'w') as f:
+                f.write(xs)
     elif decode_file:
         with open(decode_file, 'r') as f:
-            ys = tokenizer.decode(f.read())
-            print(ys)
+            token_ids = f.read()
+            token_ids = token_ids.splitlines()
+            for _token_ids in token_ids:
+                _token_ids = _token_ids.split(',')
+                _token_ids = [int(token_id) for token_id in _token_ids]
+                ys = tokenizer.decode(_token_ids)
+                print(ys)
     else:
         first = True
......
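The decode path now treats each line of the input file as one prompt's comma-separated token ids, decoding line by line. A hedged usage sketch of the same round trip outside the script; the direct sentencepiece use and the file paths are assumptions, not part of the commit:

```python
# Illustrative round trip for the new line-by-line decode path; assumes a
# sentencepiece model at the script's default path. Not part of the commit.
from sentencepiece import SentencePieceProcessor

sp = SentencePieceProcessor()
sp.Load('/data/llama/model/tokenizer.model')  # main()'s default model_file

with open('start_ids.csv') as f:  # one comma-separated prompt per line
    for line in f.read().splitlines():
        token_ids = [int(t) for t in line.split(',')]
        print(sp.Decode(token_ids))  # prints one decoded prompt per line
```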