Commit fd107062 authored by letaoqin's avatar letaoqin
Browse files

fix no bias examples

parent f90af872
......@@ -597,8 +597,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // p_acc0_biases;
{}, // p_acc1_biases;
nullptr, // p_acc0_biases;
nullptr, // p_acc1_biases;
nullptr,
nullptr,
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -686,8 +688,8 @@ int run(int argc, char* argv[])
static_cast<InputDataType*>(y_device_buf.GetDeviceBuffer()),
static_cast<ZDataType*>(z_fwd_device_buf.GetDeviceBuffer()),
static_cast<LSEDataType*>(lse_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......@@ -743,8 +745,10 @@ int run(int argc, char* argv[])
static_cast<OutputDataType*>(qgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(kgrad_device_buf.GetDeviceBuffer()),
static_cast<OutputDataType*>(vgrad_device_buf.GetDeviceBuffer()),
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
nullptr, // p_acc0_bias;
nullptr, // p_acc1_bias;
nullptr,
nullptr,
q_gs_ms_ks_lengths,
q_gs_ms_ks_strides,
k_gs_ns_ks_lengths,
......
......@@ -604,6 +604,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{},
{},
problem_descs,
QKVElementOp{},
QKVElementOp{},
......@@ -650,6 +652,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{},
{},
problem_descs,
QKVElementOp{},
QKVElementOp{},
......
......@@ -728,6 +728,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{},
{},
problem_descs_bwd,
QKVElementOp{},
QKVElementOp{},
......@@ -815,6 +817,8 @@ int run(int argc, char* argv[])
p_vgrad,
{}, // std::array<void*, 1> p_acc0_biases;
{}, // std::array<void*, 1> p_acc1_biases;
{},
{},
problem_descs_bwd,
QKVElementOp{},
QKVElementOp{},
......
......@@ -25,7 +25,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define USING_MASK 0
#define DIM 64 // DIM should be a multiple of 8.
#define DIM 128 // DIM should be a multiple of 8.
#include <iostream>
#include <numeric>
......@@ -284,7 +284,7 @@ int run(int argc, char* argv[])
bool input_permute = false;
bool output_permute = false;
float p_drop = 0.9;
float p_drop = 0.0;
const unsigned long long seed = 1;
const unsigned long long offset = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment