Commit 1661828b authored by yan.yan's avatar yan.yan
Browse files

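Fix small bug: size the workspace, padded pair and output buffers with static_num_act_in instead of real_num_act_in, and fix the argument list around the bias alpha comment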

parent 48c8434d
...@@ -229,9 +229,10 @@ int main(int argc, char **argv) { ...@@ -229,9 +229,10 @@ int main(int argc, char **argv) {
weights, pair, indices_kernel_num, arch, out_features_real.dim(0), weights, pair, indices_kernel_num, arch, out_features_real.dim(0),
inverse, is_subm, inverse, is_subm,
static_cast<int>(tv::gemm::SparseConvAlgo::kNative), static_cast<int>(tv::gemm::SparseConvAlgo::kNative),
reinterpret_cast<std::uintptr_t>(stream), bias, 1.0, reinterpret_cast<std::uintptr_t>(stream), bias,
/*bias alpha, only used for leaky relu*/, 0.0, 1.0
tv::gemm::Activation::kReLU); /*bias alpha, only used for leaky relu*/,
0.0, tv::gemm::Activation::kReLU);
} else { } else {
// regular conv use numbers in indices_kernel_num to perform gemm // regular conv use numbers in indices_kernel_num to perform gemm
// so we don't need to slice. // so we don't need to slice.
...@@ -265,10 +266,10 @@ int main(int argc, char **argv) { ...@@ -265,10 +266,10 @@ int main(int argc, char **argv) {
// only regular conv need direct table. // only regular conv need direct table.
bool use_direct_table = direct_table && !is_subm; bool use_direct_table = direct_table && !is_subm;
auto max_act_out_theory = SpconvOps::get_handcrafted_max_act_out( auto max_act_out_theory = SpconvOps::get_handcrafted_max_act_out(
real_num_act_in, ksize, stride, padding, dilation); static_num_act_in, ksize, stride, padding, dilation);
// query workspace size. // query workspace size.
int workspace_size = SpconvOps::get_indice_gen_workspace_size( int workspace_size = SpconvOps::get_indice_gen_workspace_size(
KV, real_num_act_in, out_inds_num_limit, max_act_out_theory, KV, static_num_act_in, out_inds_num_limit, max_act_out_theory,
is_subm, use_int64_hash_k, use_direct_table); is_subm, use_int64_hash_k, use_direct_table);
// you should return workspace size in tensorrt plugin method. // you should return workspace size in tensorrt plugin method.
tv::Tensor workspace = tv::empty({workspace_size}, tv::uint8, 0); tv::Tensor workspace = tv::empty({workspace_size}, tv::uint8, 0);
...@@ -278,9 +279,9 @@ int main(int argc, char **argv) { ...@@ -278,9 +279,9 @@ int main(int argc, char **argv) {
// dynamic allocator, in c++ (inference engine) we need to use // dynamic allocator, in c++ (inference engine) we need to use
// fixed-size workspace and create a static allocator. // fixed-size workspace and create a static allocator.
auto ws_tensors = SpconvOps::get_indice_gen_tensors_from_workspace( auto ws_tensors = SpconvOps::get_indice_gen_tensors_from_workspace(
workspace.raw_data(), KV, real_num_act_in, workspace.raw_data(), KV, static_num_act_in,
is_subm ? real_num_act_in : out_inds_num_limit, max_act_out_theory, is_subm ? static_num_act_in : out_inds_num_limit,
is_subm, use_int64_hash_k, use_direct_table); max_act_out_theory, is_subm, use_int64_hash_k, use_direct_table);
// pair can also have an upper bound. // pair can also have an upper bound.
// !!!!!IMPORTANT!!!!!!! if you provide a static (padded) pair_fwd and // !!!!!IMPORTANT!!!!!!! if you provide a static (padded) pair_fwd and
// other indice data, the output layout is tight pair_fwd_correct = // other indice data, the output layout is tight pair_fwd_correct =
...@@ -288,7 +289,7 @@ int main(int argc, char **argv) { ...@@ -288,7 +289,7 @@ int main(int argc, char **argv) {
// real_pair_size) this valid for pair_fwd, pair_bwd, pair_mask_fwd, // real_pair_size) this valid for pair_fwd, pair_bwd, pair_mask_fwd,
// pair_mask_bwd, mask_argsort_fwd, mask_argsort_bwd. // pair_mask_bwd, mask_argsort_fwd, mask_argsort_bwd.
int pair_fwd_size_padded = int pair_fwd_size_padded =
is_subm ? real_num_act_in : out_inds_num_limit; is_subm ? static_num_act_in : out_inds_num_limit;
tv::Tensor pair_fwd_padded = tv::Tensor pair_fwd_padded =
tv::empty({KV, pair_fwd_size_padded}, tv::int32, 0); tv::empty({KV, pair_fwd_size_padded}, tv::int32, 0);
// you can find equivalent python code of following code in python // you can find equivalent python code of following code in python
...@@ -373,7 +374,7 @@ int main(int argc, char **argv) { ...@@ -373,7 +374,7 @@ int main(int argc, char **argv) {
int num_act_out_real = std::get<1>(pair_res); int num_act_out_real = std::get<1>(pair_res);
tv::Tensor out_features = tv::Tensor out_features =
tv::empty({is_subm ? real_num_act_in : out_inds_num_limit, 64}, tv::empty({is_subm ? static_num_act_in : out_inds_num_limit, 64},
tv::float16, 0); tv::float16, 0);
auto out_features_real = auto out_features_real =
out_features.slice_first_axis(0, num_act_out_real); out_features.slice_first_axis(0, num_act_out_real);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment