    [Feature] Gather mm (#3641) · b3d3a2c4
    Israt Nisa authored
    
    
    * init
    
    * init
    
    * working cublasGemm
    
    * benchmark high-mem/low-mem, err gather_mm output
    
    * cuda kernel for bmm like kernel
    
    * removed cpu copy for E_per_Rel
    
    * benchmark code from Minjie
    
    * fixed cublas results in gathermm sorted
    
    * use GPU shared mem in unsorted gather mm
    
    * minor
    
    * Added an optimal version of gather_mm_unsorted
    
    * lint
    
    * init gather_mm_scatter
    
    * cublas transpose added
    
    * fixed h_offset for multiple rel
    
    * backward unittest
    
    * cublas support to transpose W
    
    * adding missed file
    
    * forgot to add header file
    
    * lint
    
    * lint
    
    * cleanup
    
    * lint
    
    * docstring
    
    * lint
    
    * added unittest
    
    * lint
    
    * lint
    
    * unittest
    
    * changed err type
    
    * skip cpu test
    
    * skip CPU code
    
    * move in-len loop inside
    
    * lint
    
    * added check different dim length for B
    
    * w_per_len is optional now
    
    * moved gather_mm to pytorch/backend with backward support
    
    * removed a_/b_trans support
    
    * transpose op inside GEMM call
    
    * removed out alloc from API, changed W 2D to 3D
    
    * Added se_gather_mm, Separate API for sortedE
    
    * Fixed gather_mm (unsorted) user interface
    
    * unsorted gmm backward + separate CAPI for un/sorted A
    
    * typecast to float to support atomicAdd
    
    * lint typecast
    
    * lint
    
    * added gather_mm_scatter
    
    * minor
    
    * const
    
    * design changes
    
    * Added idx_a, idx_b support gmm_scatter
    
    * dgl doc
    
    * lint
    
    * adding gather_mm in ops
    
    * lint
    
    * lint
    
    * minor
    
    * removed benchmark files
    
    * minor
    
    * empty commit
    Co-authored-by: Israt Nisa <nisisrat@amazon.com>
gather_mm.h 3.6 KB