Commit fbd45579 authored by PanZezhong's avatar PanZezhong
Browse files

issue/210: support random sampling with a per-worker random number generator

parent 3747f7f3
......@@ -19,7 +19,8 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
has_job_(false),
job_done_(false),
should_exit_(false),
init_done_(false) {
init_done_(false),
rng_(std::random_device{}()) {
if (cache_config != nullptr) {
pending_cache_config_ = cache_config->unique_copy();
}
......@@ -252,7 +253,6 @@ void RankWorker::thread_loop() {
auto temperature{local_args.temperature};
auto top_p{local_args.top_p};
auto top_k{local_args.top_k};
auto random_val{local_args.random_val};
const auto &logits_shape{logits->shape()};
const auto &vocab_size{logits_shape[2]};
......@@ -267,6 +267,7 @@ void RankWorker::thread_loop() {
for (auto i{decltype(n_req)(0)}; i < n_req; ++i) {
auto score{logits->view({batch_size * total_len, vocab_size})->narrow({{0, size_t(input_offsets[i + 1] - 1), 1}})->view({vocab_size})};
auto out{output_ids->narrow({{0, i, 1}})->view({})};
float random_val = std::uniform_real_distribution<float>(0, 1)(rng_);
infinicore::op::random_sample_(
out, score, random_val, top_p, top_k, temperature);
}
......
......@@ -7,6 +7,7 @@
#include <any>
#include <condition_variable>
#include <mutex>
#include <random>
#include <string>
#include <thread>
#include <vector>
......@@ -45,8 +46,6 @@ public:
float top_p{1};
float random_val{0.1};
infinilm::InfinilmModel::Input to_model_input(infinicore::Device device) const;
};
......@@ -114,6 +113,9 @@ private:
std::thread thread_;
std::mutex mutex_;
std::condition_variable cv_;
// Random
std::mt19937 rng_;
};
} // namespace infinilm::engine
......@@ -179,6 +179,27 @@ def get_args():
action="store_true",
help="skip loading model weights",
)
parser.add_argument(
"--top-k",
type=int,
default=1,
help="top k sampling",
)
parser.add_argument(
"--top-p",
type=float,
default=1.0,
help="top p sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="sampling temperature",
)
return parser.parse_args()
......@@ -247,6 +268,9 @@ class TestModel:
batch_size: int,
input_len: int,
output_len: int,
top_k=1,
top_p=1.0,
temperature=1.0,
):
input_ids = repeat_prompt(self.input_ids_list[0], target_length=input_len)
input_ids_list = [input_ids] * batch_size
......@@ -260,7 +284,13 @@ class TestModel:
print("=================== start generate ====================")
output_ids = self.model.generate(
input_ids_infini,
GenerationConfig(max_new_tokens=output_len, eos_token_id=[]),
GenerationConfig(
max_new_tokens=output_len,
eos_token_id=[],
top_k=top_k,
top_p=top_p,
temperature=temperature,
),
_measure_and_log_time=True,
)
t2 = time.time()
......@@ -349,4 +379,7 @@ if __name__ == "__main__":
batch_size=batch_size,
input_len=input_len,
output_len=output_len,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
)
......@@ -89,6 +89,27 @@ def get_args():
help="use paged cache",
)
parser.add_argument(
"--top-k",
type=int,
default=1,
help="top k sampling",
)
parser.add_argument(
"--top-p",
type=float,
default=1.0,
help="top p sampling",
)
parser.add_argument(
"--temperature",
type=float,
default=1.0,
help="sampling temperature",
)
return parser.parse_args()
......@@ -99,6 +120,9 @@ def test(
infini_device=infinicore.device("cpu", 0),
tp=1,
enable_paged_attn=False,
top_k=1,
top_p=1.0,
temperature=1.0,
):
model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- #
......@@ -186,7 +210,10 @@ def test(
output_ids = model.generate(
input_ids_infini,
GenerationConfig(
max_new_tokens=max_new_tokens, temperature=1, top_k=1, top_p=0.8
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
),
_measure_and_log_time=True,
)
......@@ -243,4 +270,7 @@ if __name__ == "__main__":
infini_device=infini_device,
tp=tp,
enable_paged_attn=enable_paged_attn,
top_k=args.top_k,
top_p=args.top_p,
temperature=args.temperature,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment