// dnnl_helper.cpp
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <list>
#include <new>
#include <optional>
#include <sstream>
#include <unordered_map>
#include <utility>

#include "common/memory_desc.hpp"
#include "common/memory.hpp"

#include "dnnl_helper.h"

// Returns the process-wide CPU engine (device index 0). It is created lazily
// on first call and shared by every primitive and memory object in this file.
static dnnl::engine& default_engine() {
  static dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0);
  return cpu_engine;
}

// Returns the process-wide stream bound to default_engine(). It is created
// lazily on first call; every primitive execution in this file uses it.
static dnnl::stream& default_stream() {
  static dnnl::stream cpu_stream(default_engine());
  return cpu_stream;
}

// Destroys a matmul handler that was handed out as an opaque int64_t handle.
// NOTE(review): the handle may point at a W8A8MatMulPrimitiveHandler or a
// MatMulPrimitiveHandler. Deleting through the base pointer requires
// DNNLMatMulPrimitiveHandler to declare a virtual destructor; confirm in
// dnnl_helper.h.
void release_dnnl_matmul_handler(int64_t handler) {
  delete reinterpret_cast<DNNLMatMulPrimitiveHandler*>(handler);
}

DNNLScratchPadManager::DNNLScratchPadManager() : size_(0), ptr_(nullptr) {
  this->realloc(allocation_unit * 128);
}

// Ensures the shared scratchpad holds at least `new_size` bytes after
// rounding with round(). The buffer only ever grows; smaller requests are
// no-ops.
//
// When growing, the previous buffer is released first, so older buffers are
// not leaked. Throws std::bad_alloc if the 64-byte-aligned allocation fails,
// instead of leaving a null buffer paired with a non-zero recorded size.
void DNNLScratchPadManager::realloc(size_t new_size) {
  new_size = round(new_size);
  if (new_size <= size_) {
    return;
  }

  // Within this file, execute() re-points the scratchpad memory object via
  // set_data_handle() after get_matmul_cache() (the only caller that grows the
  // buffer), so no primitive runs against the freed pointer.
  std::free(ptr_);  // std::free(nullptr) is a no-op on the first call.
  ptr_ = std::aligned_alloc(64, new_size);
  if (ptr_ == nullptr) {
    size_ = 0;
    throw std::bad_alloc();
  }
  size_ = new_size;
}

// Returns the single process-wide scratchpad manager. It is a function-local
// static, so it is constructed on the first call and the same address is
// returned afterwards.
DNNLScratchPadManager* DNNLScratchPadManager::get_dnnl_scratchpad_manager() {
  static DNNLScratchPadManager instance;
  DNNLScratchPadManager* const shared = &instance;
  return shared;
}

// Fixed-capacity least-recently-used cache mapping KT -> VT.
//
// Entries live in `values_`, ordered from most to least recently used.
// `key_to_value_` indexes them by key. get_or_create() either returns a cached
// value (moving it to the front) or builds one with `creator`, inserts it at
// the front, and evicts the back entry when the cache is full.
// KT needs a std::hash specialization and operator==. There is no internal
// locking.
template <typename KT, typename VT>
class DNNLPrimitiveCache {
 public:
  using cache_value_t = std::pair<KT, VT>;
  using result_value_t = VT;
  using container_t = std::list<cache_value_t>;
  using value_iterator_t = typename container_t::iterator;
  using map_t = std::unordered_map<KT, value_iterator_t>;
  using creator_t = VT (*)();

 public:
  // `capacity` is the maximum number of entries and must be > 0. The index
  // starts with min(256, capacity) buckets.
  DNNLPrimitiveCache(size_t capacity)
      : capacity_(capacity),
        values_(),
        // std::min<size_t> rather than std::min(256lu, ...): `256lu` is
        // unsigned long, which is not size_t on LLP64 or 32-bit targets, and
        // the mismatched argument types break std::min's template deduction.
        key_to_value_(std::min<size_t>(256, capacity)) {
    assert(capacity > 0);
  }

  // Returns the value cached for `key`, marking it most recently used.
  // On a miss, calls `creator()` once, caches its result and returns it.
  template <typename F>
  result_value_t get_or_create(const KT& key, F&& creator) {
    std::optional<value_iterator_t> value = get_value(key);
    if (value.has_value()) {
      return value.value()->second;
    }
    return add_value({key, creator()})->second;
  }

  size_t size() const { return values_.size(); }

 private:
  // Debug helper: prints the list and the index to stdout.
  void dump_data() {
    std::stringstream ss;
    ss << "table_id: " << std::hex << reinterpret_cast<size_t>(this) << std::dec
       << "\n";
    ss << "container: [";
    for (auto&& iter : values_) {
      ss << "(" << iter.first << ", " << std::hex
         << reinterpret_cast<size_t>(iter.second.get()) << "), " << std::dec;
    }
    ss << "]\n";

    ss << "map: [";
    for (auto&& iter : key_to_value_) {
      ss << "(" << iter.first << ", " << iter.second->first << ", " << std::hex
         << reinterpret_cast<size_t>(iter.second->second.get()) << std::dec
         << "), ";
    }
    ss << "]\n";
    std::printf("%s\n", ss.str().c_str());
  }

  // Inserts `new_value` as the most recently used entry, first evicting the
  // least recently used one if the cache is full. Callers guarantee the key
  // is not already present.
  value_iterator_t add_value(cache_value_t&& new_value) {
    if (size() >= capacity_) {
      cache_value_t& last_item = values_.back();
      key_to_value_.erase(last_item.first);
      values_.pop_back();
    }

    auto& added_value = values_.emplace_front(std::move(new_value));
    key_to_value_.emplace(added_value.first, values_.begin());
    return values_.begin();
  }

  // Looks up `key`. On a hit, splices the entry to the front (list iterators
  // stay valid, so the index needs no update) and returns it.
  std::optional<value_iterator_t> get_value(const KT& key) {
    // Fast path: repeated lookups of the most recently used key.
    if (!values_.empty() && key == values_.front().first) {
      return values_.begin();
    }

    auto value_map_iterator = key_to_value_.find(key);
    if (value_map_iterator == key_to_value_.end()) {
      return {};
    }
    values_.splice(values_.begin(), values_, value_map_iterator->second);
    return value_map_iterator->second;
  }

 private:
  const size_t capacity_;
  container_t values_;
  map_t key_to_value_;
};

// Base constructor shared by the matmul handler flavours. It records the
// weight (B) geometry, the output type and the primitive-cache size from
// `args`. `b_type` comes from the derived class: s8 for the W8A8 handler,
// args.ab_type for the floating-point handler.
DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
    const Args& args, dnnl::memory::data_type b_type)
    : b_n_size_(args.b_n_size),
      b_n_stride_(args.b_n_stride),
      b_k_size_(args.b_k_size),
      b_k_stride_(args.b_k_stride),
      b_type_(b_type),
      c_type_(args.c_type),
      // Slots for runtime-rebindable memory; derived classes in this file
      // use indices 0-5 (W8A8) and 0-3 (floating point).
      runtime_memory_ptrs_(8),
      primitive_cache_size_(args.primitive_cache_size) {
  assert(primitive_cache_size_ > 0);
}

void DNNLMatMulPrimitiveHandler::prepack_weight(
140
141
    void* original_b_ptr, dnnl::memory::desc original_b_md,
    dnnl::memory::desc b_target_mem_desc) {
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
  dnnl::memory original_weight(original_b_md, default_engine(), original_b_ptr);
  dnnl::memory packed_weight(b_target_mem_desc, default_engine());
  {
    dnnl::reorder(original_weight, packed_weight)
        .execute(default_stream(), original_weight, packed_weight);
    default_stream().wait();
  }
  memory_cache_[DNNL_ARG_WEIGHTS] = packed_weight;
  b_target_mem_desc_ = b_target_mem_desc;
}

// Stores raw pointers to `memory_ptr`'s storage and descriptor in runtime slot
// `index`. execute() later rebinds the data handle and patches dims/strides
// through these pointers, which is why the descriptor is const_cast here.
void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
    size_t index, dnnl_memory* memory_ptr) {
  runtime_memory_ptrs_[index] = {
      memory_ptr->memory_storage(),
      const_cast<dnnl_memory_desc*>(memory_ptr->md())};
}

std::pair<dnnl::impl::memory_storage_t*, dnnl_memory_desc*>
DNNLMatMulPrimitiveHandler::get_runtime_memory_ptr(size_t index) {
  return runtime_memory_ptrs_[index];
}

// Mixes `value` into `seed` (boost::hash_combine recipe). XOR-ing member hashes
// directly would make symmetric keys collide. For example, with libstdc++'s
// identity hash for integers, every square weight shape (b_n_size ==
// b_k_size) would hash to 0, and (n, k) would collide with (k, n).
static inline size_t dnnl_cache_key_hash_combine(size_t seed, size_t value) {
  return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

namespace std {
// Hashes every field compared by the matching operator== below.
template <>
struct hash<W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey> {
  size_t operator()(
      const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
    size_t seed = hash<dnnl_dim_t>()(val.b_n_size);
    seed = dnnl_cache_key_hash_combine(seed, hash<dnnl_dim_t>()(val.b_k_size));
    seed = dnnl_cache_key_hash_combine(
        seed, hash<int>()(static_cast<int>(val.a_qs)));
    seed = dnnl_cache_key_hash_combine(
        seed, hash<int>()(static_cast<int>(val.b_qs)));
    seed = dnnl_cache_key_hash_combine(seed, hash<bool>()(val.use_azp));
    seed = dnnl_cache_key_hash_combine(
        seed, hash<int>()(static_cast<int>(val.c_type)));
    return seed;
  }
};

template <>
struct hash<W8A8MatMulPrimitiveHandler::MSizeCacheKey> {
  size_t operator()(
      const W8A8MatMulPrimitiveHandler::MSizeCacheKey& val) const {
    size_t seed = hash<dnnl_dim_t>()(val.a_m_size);
    seed = dnnl_cache_key_hash_combine(seed, hash<bool>()(val.use_bias));
    seed = dnnl_cache_key_hash_combine(
        seed, hash<int>()(static_cast<int>(val.bias_type)));
    return seed;
  }
};

template <>
struct hash<MatMulPrimitiveHandler::ClassMatmulCacheKey> {
  size_t operator()(
      const MatMulPrimitiveHandler::ClassMatmulCacheKey& val) const {
    size_t seed = hash<dnnl_dim_t>()(val.b_n_size);
    seed = dnnl_cache_key_hash_combine(seed, hash<dnnl_dim_t>()(val.b_k_size));
    return seed;
  }
};

template <>
struct hash<MatMulPrimitiveHandler::MSizeCacheKey> {
  size_t operator()(const MatMulPrimitiveHandler::MSizeCacheKey& val) const {
    size_t seed = hash<dnnl_dim_t>()(val.a_m_size);
    seed =
        dnnl_cache_key_hash_combine(seed, hash<dnnl_dim_t>()(val.a_m_stride));
    seed = dnnl_cache_key_hash_combine(seed, hash<bool>()(val.use_bias));
    seed = dnnl_cache_key_hash_combine(
        seed, hash<int>()(static_cast<int>(val.bias_type)));
    return seed;
  }
};
}  // namespace std

// Field-wise equality for the per-weight-class cache key; must compare exactly
// the fields hashed by std::hash<ClassMatmulCacheKey>.
bool operator==(const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
                const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
  if (l.b_n_size != r.b_n_size || l.b_k_size != r.b_k_size) {
    return false;
  }
  if (l.a_qs != r.a_qs || l.b_qs != r.b_qs) {
    return false;
  }
  return l.use_azp == r.use_azp && l.c_type == r.c_type;
}

// Field-wise equality for the per-M cache key (M, bias presence, bias type).
bool operator==(const W8A8MatMulPrimitiveHandler::MSizeCacheKey& l,
                const W8A8MatMulPrimitiveHandler::MSizeCacheKey& r) {
  const bool same_bias =
      l.use_bias == r.use_bias && l.bias_type == r.bias_type;
  return same_bias && l.a_m_size == r.a_m_size;
}

// Two weight classes are equal when their N and K sizes match.
bool operator==(const MatMulPrimitiveHandler::ClassMatmulCacheKey& l,
                const MatMulPrimitiveHandler::ClassMatmulCacheKey& r) {
  if (l.b_n_size != r.b_n_size) {
    return false;
  }
  return l.b_k_size == r.b_k_size;
}

// Field-wise equality for the per-M cache key (M, activation row stride,
// bias presence, bias type).
bool operator==(const MatMulPrimitiveHandler::MSizeCacheKey& l,
                const MatMulPrimitiveHandler::MSizeCacheKey& r) {
  const bool same_shape =
      l.a_m_size == r.a_m_size && l.a_m_stride == r.a_m_stride;
  const bool same_bias =
      l.use_bias == r.use_bias && l.bias_type == r.bias_type;
  return same_shape && same_bias;
}

// Returns the per-M primitive cache shared by all W8A8 handlers with the same
// weight shape / quantization configuration (`key`). A new cache holding up to
// `cache_size` primitives is created on first request. The outer cache keeps at
// most 128 such classes (presumably LRU-evicted via DNNLPrimitiveCache; the
// ClassMatmulCache alias lives in the header).
static std::shared_ptr<W8A8MatMulPrimitiveHandler::MSizeCache>
get_w8a8_class_primitive_cache(
    const W8A8MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
    int64_t cache_size) {
  using MSizeCache = W8A8MatMulPrimitiveHandler::MSizeCache;
  static W8A8MatMulPrimitiveHandler::ClassMatmulCache cache(128);
  assert(cache_size > 0);
  auto make_m_size_cache = [cache_size]() {
    return std::make_shared<MSizeCache>(cache_size);
  };
  return cache.get_or_create(key, make_m_size_cache);
}

// Builds an s8 x s8 matmul handler. It validates the quantization
// configuration, prepacks the weights into oneDNN's preferred layout, and
// creates the runtime memory placeholders used by execute().
W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args)
    : DNNLMatMulPrimitiveHandler(
          static_cast<const DNNLMatMulPrimitiveHandler::Args&>(args),
          dnnl::memory::data_type::s8),
      use_azp_(args.use_a_zero_point),
      a_qs_(args.a_quantization_strategy),
      b_qs_(args.b_quantization_strategy),
      m_size_cache_(nullptr) {
  // Activations cannot be quantized per output channel, and weights cannot be
  // quantized per token.
  assert(a_qs_ != QuantizationStrategy::PER_OUTPUT_CHANNEL);
  assert(b_qs_ != QuantizationStrategy::PER_TOKEN);
  // An activation zero point is not supported with PER_TOKEN activations.
  assert(a_qs_ != QuantizationStrategy::PER_TOKEN || !use_azp_);

  // The caller's weight layout, described by its K/N strides.
  const dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
                                         {b_k_stride_, b_n_stride_});

  // Query the weight layout oneDNN picks for a runtime-M primitive
  // (first_time = true requests format_tag::any), then reorder into it once.
  const MSizeCacheKey layout_probe_key{
      .a_m_size = DNNL_RUNTIME_DIM_VAL,
      .use_bias = false,
      .bias_type = dnnl::memory::data_type::undef};
  const dnnl::memory::desc packed_b_md =
      create_primitive_desc(layout_probe_key, true).weights_desc();
  prepack_weight(args.b_ptr, original_b_md, packed_b_md);

  init_runtime_memory_cache(args);
}

// Runs one s8 matmul. It binds this call's buffers to the cached runtime memory
// slots, fetches (or builds) the primitive for this M/bias configuration, then
// executes it synchronously on the default stream.
void W8A8MatMulPrimitiveHandler::execute(ExecArgs& args) {
  // Slots 0/1: activation and output. Only M varies between calls, so it is
  // patched into dims[0] of the cached descriptors in place.
  auto&& [src_storage, src_md] = get_runtime_memory_ptr(0);
  auto&& [dst_storage, dst_md] = get_runtime_memory_ptr(1);
  src_storage->set_data_handle((void*)args.a_ptr);
  src_md->dims[0] = args.a_m_size;
  dst_storage->set_data_handle((void*)args.c_ptr);
  dst_md->dims[0] = args.a_m_size;

  // Slot 2: per-tensor activation scale. PER_TOKEN scales are applied outside.
  if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
    get_runtime_memory_ptr(2).first->set_data_handle((void*)args.a_scales_ptr);
  }
  // Slot 3: activation zero point.
  if (use_azp_) {
    get_runtime_memory_ptr(3).first->set_data_handle(
        (void*)args.a_zero_points_ptr);
  }
  // Slot 4: bias.
  if (args.use_bias) {
    get_runtime_memory_ptr(4).first->set_data_handle((void*)args.bias_ptr);
  }

  // Building a new primitive may grow (and move) the shared scratchpad, so the
  // scratchpad handle is bound only after this call.
  dnnl::matmul matmul = get_matmul_cache(args);

  // Slot 5: user-managed scratchpad.
  get_runtime_memory_ptr(5).first->set_data_handle(
      DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());

  matmul.execute(default_stream(), memory_cache_);
  default_stream().wait();
}

// Returns the matmul primitive for `key`, creating and caching it on a miss.
dnnl::matmul W8A8MatMulPrimitiveHandler::get_matmul_cache(
    const MSizeCacheKey& key) {
  // Lazily attach to the process-wide per-M cache shared by all W8A8
  // handlers with the same weight shape, quantization setup and output type.
  if (!m_size_cache_) {
    const ClassMatmulCacheKey class_key = {.b_n_size = b_n_size_,
                                           .b_k_size = b_k_size_,
                                           .a_qs = a_qs_,
                                           .b_qs = b_qs_,
                                           .use_azp = use_azp_,
                                           .c_type = c_type_};
    m_size_cache_ =
        get_w8a8_class_primitive_cache(class_key, primitive_cache_size_);
  }

  // On a miss, build the primitive and grow the shared scratchpad so it can
  // hold this primitive's scratchpad requirement.
  auto build_matmul = [&]() {
    dnnl::matmul::primitive_desc desc = create_primitive_desc(key, false);
    DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(
        desc.scratchpad_desc().get_size());
    return dnnl::matmul(desc);
  };
  return m_size_cache_->get_or_create(key, build_matmul);
}

void W8A8MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
  memory_cache_[DNNL_ARG_SRC] = dnnl::memory({{1, b_k_size_},
                                              dnnl::memory::data_type::s8,
                                              dnnl::memory::format_tag::ab},
                                             default_engine(), nullptr);
  set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
  memory_cache_[DNNL_ARG_DST] =
      dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());

  // For PER_TOKEN, scales will be applied in outside epilogue
  if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
    memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC] = dnnl::memory(
        {{1}, dnnl::memory::data_type::f32, {1}}, default_engine(), nullptr);
    set_runtime_memory_ptr(
        2, memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC].get());
    if (use_azp_) {
      memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = dnnl::memory(
          {{1}, dnnl::memory::data_type::s32, {1}}, default_engine(), nullptr);
      set_runtime_memory_ptr(
          3, memory_cache_[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC].get());
    }
  }

  if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
    memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
        dnnl::memory({{1}, dnnl::memory::data_type::f32, {1}}, default_engine(),
                     (void*)args.b_scales_ptr);
  } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
    memory_cache_[DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS] =
        dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                     default_engine(), (void*)args.b_scales_ptr);
  }

  memory_cache_[DNNL_ARG_BIAS] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(4, memory_cache_[DNNL_ARG_BIAS].get());
356
357
358
359
360

  memory_cache_[DNNL_ARG_SCRATCHPAD] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(5, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
}

// Builds the s8 matmul primitive descriptor for M = key.a_m_size.
// When `first_time` is true, the weight layout is left as format_tag::any so
// oneDNN can choose it (used to decide how to prepack). Otherwise the prepacked
// layout b_target_mem_desc_ is used.
dnnl::matmul::primitive_desc W8A8MatMulPrimitiveHandler::create_primitive_desc(
    const MSizeCacheKey& key, bool first_time) {
  using dt = dnnl::memory::data_type;
  using tag = dnnl::memory::format_tag;

  const dnnl::memory::desc a_md({key.a_m_size, b_k_size_}, dt::s8, tag::ab);
  const dnnl::memory::desc b_md =
      first_time
          ? dnnl::memory::desc({b_k_size_, b_n_size_}, dt::s8, tag::any)
          : b_target_mem_desc_;
  const dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, tag::ab);

  dnnl::primitive_attr attr;
  // Scratchpad memory is supplied by the caller at execution time.
  attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  // For PER_TOKEN, scales will be applied in outside epilogue.
  if (a_qs_ == QuantizationStrategy::PER_TENSOR) {
    attr.set_scales_mask(DNNL_ARG_SRC, 0);
    if (use_azp_) {
      attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
    }
  }

  // Weight scale mask: 0 = a single scale; 2 (bit 1) = one scale along
  // weight dim 1, i.e. per N column.
  if (b_qs_ == QuantizationStrategy::PER_TENSOR) {
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
  } else if (b_qs_ == QuantizationStrategy::PER_OUTPUT_CHANNEL) {
    attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2);
  }

  if (!key.use_bias) {
    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
                                        attr);
  }

  // For PER_TOKEN, bias will be applied in epilogue.
  assert(a_qs_ == QuantizationStrategy::PER_TENSOR);
  const dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type,
                                   {b_n_size_, 1});
  return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
                                      c_md, attr);
}

// Builds a floating-point (f32/bf16/f16) matmul handler. It prepacks the
// weights into the layout oneDNN prefers and creates the runtime memory
// placeholders used by execute().
MatMulPrimitiveHandler::MatMulPrimitiveHandler(const Args& args)
    : DNNLMatMulPrimitiveHandler(
          static_cast<DNNLMatMulPrimitiveHandler::Args>(args), args.ab_type),
      m_size_cache_(nullptr) {
  // Activations and weights share the type args.ab_type (stored by the base
  // class as b_type_). The check reads it from `args` directly rather than
  // through a member this constructor never initializes.
  assert(args.ab_type == dnnl::memory::data_type::f32 ||
         args.ab_type == dnnl::memory::data_type::bf16 ||
         args.ab_type == dnnl::memory::data_type::f16);

  // The caller's weight layout, described by its K/N strides.
  const dnnl::memory::desc original_b_md({b_k_size_, b_n_size_}, b_type_,
                                         {b_k_stride_, b_n_stride_});

  // Key used only to ask oneDNN for its preferred weight layout.
  const MSizeCacheKey layout_probe_key{
#ifdef VLLM_USE_ACL
      // The Arm Compute Library (ACL) backend for oneDNN does not support
      // runtime dimensions, so M is set to a default value.
      .a_m_size = 128,
      .a_m_stride = b_k_size_,
#else
      .a_m_size = DNNL_RUNTIME_DIM_VAL,
      .a_m_stride = DNNL_RUNTIME_DIM_VAL,
#endif
      .use_bias = false,
      .bias_type = dnnl::memory::data_type::undef};

  prepack_weight(args.b_ptr, original_b_md,
                 create_primitive_desc(layout_probe_key, true).weights_desc());
  init_runtime_memory_cache(args);
}

// Returns the per-M primitive cache shared by all floating-point handlers with
// the same weight shape (`key`). A new cache holding up to `cache_size`
// primitives is created on first request. The outer cache keeps at most 128
// weight shapes.
static std::shared_ptr<MatMulPrimitiveHandler::MSizeCache>
get_matul_class_primitive_cache(
    const MatMulPrimitiveHandler::ClassMatmulCacheKey& key,
    int64_t cache_size) {
  using MSizeCache = MatMulPrimitiveHandler::MSizeCache;
  static MatMulPrimitiveHandler::ClassMatmulCache cache(128);
  assert(cache_size > 0);
  auto make_m_size_cache = [cache_size]() {
    return std::make_shared<MSizeCache>(cache_size);
  };
  return cache.get_or_create(key, make_m_size_cache);
}

// Runs one floating-point matmul. It binds this call's buffers to the cached
// runtime slots, fetches (or builds) the primitive for this M/stride/bias
// configuration, then executes it synchronously on the default stream.
void MatMulPrimitiveHandler::execute(ExecArgs& args) {
  // Slots 0/1: activation and output. M and the activation row stride vary
  // per call and are patched into the cached descriptors in place.
  auto&& [src_storage, src_md] = get_runtime_memory_ptr(0);
  auto&& [dst_storage, dst_md] = get_runtime_memory_ptr(1);
  src_storage->set_data_handle((void*)args.a_ptr);
  src_md->dims[0] = args.a_m_size;
  src_md->format_desc.blocking.strides[0] = args.a_m_stride;
  dst_storage->set_data_handle((void*)args.c_ptr);
  dst_md->dims[0] = args.a_m_size;

#ifndef VLLM_USE_ACL
  // Slot 2: bias. The ACL backend of oneDNN does not support a bias input;
  // there the bias is handled by:
  //   1. copying it into the result tensor, and
  //   2. attaching a fused-sum post-op to the matmul primitive.
  if (args.use_bias) {
    get_runtime_memory_ptr(2).first->set_data_handle((void*)args.bias_ptr);
  }
#endif

  dnnl::matmul matmul = get_matmul_cache(args);

#ifdef VLLM_USE_ACL
  // With the ACL backend of oneDNN, the required weight layout may change when
  // the source tensor dims change. This rarely happens in practice, but the
  // API allows it, so repack the weights whenever the layout differs.
  const dnnl::memory::desc expected_b_md =
      dnnl::matmul::primitive_desc(
          const_cast<dnnl_primitive_desc_t>(matmul.get_primitive_desc()))
          .weights_desc();
  if (expected_b_md != b_target_mem_desc_) {
    prepack_weight(memory_cache_[DNNL_ARG_WEIGHTS].get_data_handle(),
                   b_target_mem_desc_, expected_b_md);
  }
#endif

  // Slot 3: user-managed scratchpad. It is bound after get_matmul_cache(),
  // which may have grown the shared buffer.
  get_runtime_memory_ptr(3).first->set_data_handle(
      DNNLScratchPadManager::get_dnnl_scratchpad_manager()->get_data<void>());

  matmul.execute(default_stream(), memory_cache_);
  default_stream().wait();
}

// Returns the matmul primitive for `key`, creating and caching it on a miss.
dnnl::matmul MatMulPrimitiveHandler::get_matmul_cache(
    const MSizeCacheKey& key) {
  // Lazily attach to the process-wide per-M cache shared by all
  // floating-point handlers with the same weight shape.
  if (!m_size_cache_) {
    const ClassMatmulCacheKey class_key = {.b_n_size = b_n_size_,
                                           .b_k_size = b_k_size_};
    m_size_cache_ =
        get_matul_class_primitive_cache(class_key, primitive_cache_size_);
  }

  // On a miss, build the primitive and grow the shared scratchpad to fit it.
  auto build_matmul = [&]() {
    dnnl::matmul::primitive_desc desc = create_primitive_desc(key, false);
    DNNLScratchPadManager::get_dnnl_scratchpad_manager()->realloc(
        desc.scratchpad_desc().get_size());
    return dnnl::matmul(desc);
  };
  return m_size_cache_->get_or_create(key, build_matmul);
}

// Builds the floating-point matmul primitive descriptor for the M size and
// activation row stride in `key`.
// When `first_time` is true, the activations are contiguous and the weights
// use format_tag::any (layout probe for prepacking). Otherwise the activation
// row stride is key.a_m_stride and the weights use the prepacked layout (or
// "any" under ACL).
dnnl::matmul::primitive_desc MatMulPrimitiveHandler::create_primitive_desc(
    const MSizeCacheKey& key, bool first_time) {
  using tag = dnnl::memory::format_tag;
  const dnnl::memory::dims a_dims = {key.a_m_size, b_k_size_};
  const dnnl::memory::dims b_dims = {b_k_size_, b_n_size_};

  dnnl::memory::desc a_md;
  dnnl::memory::desc b_md;
  if (first_time) {
    a_md = dnnl::memory::desc(a_dims, b_type_, tag::ab);
    b_md = dnnl::memory::desc(b_dims, b_type_, tag::any);
  } else {
    a_md = dnnl::memory::desc(a_dims, b_type_, {key.a_m_stride, 1});
#ifdef VLLM_USE_ACL
    // ACL's backend of oneDNN always expects the weight format to be "any".
    b_md = dnnl::memory::desc(b_dims, b_type_, tag::any);
#else
    b_md = b_target_mem_desc_;
#endif
  }
  const dnnl::memory::desc c_md({key.a_m_size, b_n_size_}, c_type_, tag::ab);

  dnnl::primitive_attr attr;
  // Scratchpad memory is supplied by the caller at execution time.
  attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

  if (!key.use_bias) {
    return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
                                        attr);
  }

#ifdef VLLM_USE_ACL
  // ACL's matmuls don't accept a bias descriptor, so the bias is applied
  // through a fused-sum post-op (it is expected to already be in the output
  // tensor).
  dnnl::post_ops post_ops;
  post_ops.append_sum();
  attr.set_post_ops(post_ops);
  return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, c_md,
                                      attr);
#else
  const dnnl::memory::desc bias_md({1, b_n_size_}, key.bias_type,
                                   {b_n_size_, 1});
  return dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, bias_md,
                                      c_md, attr);
#endif
}

void MatMulPrimitiveHandler::init_runtime_memory_cache(const Args& args) {
  memory_cache_[DNNL_ARG_SRC] = dnnl::memory(
      {{1, b_k_size_}, b_type_, {b_k_size_, 1}}, default_engine(), nullptr);
  set_runtime_memory_ptr(0, memory_cache_[DNNL_ARG_SRC].get());
  memory_cache_[DNNL_ARG_DST] =
      dnnl::memory({{1, b_n_size_}, c_type_, dnnl::memory::format_tag::ab},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(1, memory_cache_[DNNL_ARG_DST].get());

562
563
// ACL matmuls don't support bias_md, so we don't need these
#ifndef VLLM_USE_ACL
564
565
566
567
  memory_cache_[DNNL_ARG_BIAS] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(2, memory_cache_[DNNL_ARG_BIAS].get());
568
#endif
569
570
571
572
573
  memory_cache_[DNNL_ARG_SCRATCHPAD] =
      dnnl::memory({{b_n_size_}, dnnl::memory::data_type::f32, {1}},
                   default_engine(), nullptr);
  set_runtime_memory_ptr(3, memory_cache_[DNNL_ARG_SCRATCHPAD].get());
}

// Reports whether this build was compiled against the Arm Compute Library
// (ACL) backend of oneDNN, i.e. whether VLLM_USE_ACL was defined.
bool is_onednn_acl_supported() {
#ifdef VLLM_USE_ACL
  constexpr bool kAclEnabled = true;
#else
  constexpr bool kAclEnabled = false;
#endif
  return kAclEnabled;
}