"include/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "1d53731f2c210557caab5660dbe2c578dce6114f"
Commit fabfa2e2 authored by Pan Zezhong
Browse files

use int64 for random sample

parent b0010cbc
......@@ -135,8 +135,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di}, stream);
auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, stream);
auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, stream);
auto result_buf = Tensor::buffer(INFINI_DTYPE_U32, {nreq}, stream);
auto result_cpu = std::vector<uint32_t>(nreq);
auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, stream);
auto result_cpu = std::vector<int64_t>(nreq);
// Prepare inputs
auto batch_pos_ids = std::vector<uint32_t>(ntok);
size_t req_start = 0;
......@@ -270,7 +271,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
infiniopRandomSampleDescriptor_t desc_sample;
RUN_INFINI(infiniopCreateRandomSampleDescriptor(
rsrc.handle, &desc_sample,
TensorDesc::create(INFINI_DTYPE_U32, {}, {})->get(),
TensorDesc::create(INFINI_DTYPE_I64, {}, {})->get(),
TensorDesc::create(dt_logits, {dvoc}, {1})->get()));
RUN_INFINI(infiniopGetRandomSampleWorkspaceSize(desc_sample, &temp_size));
workspace_size = std::max(workspace_size, temp_size);
......@@ -398,7 +399,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
}
RUN_INFINI(infinirtStreamSynchronize(stream));
RUN_INFINI(infinirtMemcpy(result_cpu.data(), result_buf->data(),
sizeof(uint32_t) * nreq, INFINIRT_MEMCPY_D2H));
sizeof(int64_t) * nreq, INFINIRT_MEMCPY_D2H));
for (uint32_t req = 0; req < nreq; req++) {
ans[req] = result_cpu[req];
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment