1. Vector-vector multiplication => no temp memory, forward memory cost is 1 element (could be neglected), backward memory cost is the same
as two input vectors.
2. Matrix-vector multiplication => if the first input is matrix, no temp memory is needed, otherwise, there is a temp memory in the backward
phase for the transpose of the matrix. The forward memory cost is the size of output tensor, backward memory cost is the size of the two inputs; if
the first input is vector, the forward memory cost is the size of the output tensor, and during the backward phase, it will allocate a temp memory
the same size as the input matrix, and allocate memory for the gradient of two inputs.
3. Batched Matrix-vector multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of
output tensor, backward memory cost is the size of the two inputs; if the second input is the batched matrix, the matmul will allocate memory for
the gradient of the batched matrix in the forward phase (as they create a new tensor without the former batches), so the forward memory cost is
the output tensor and the newly created matrix (take the same amount of memory of the input batched matrix). During the backward phase, it will
allocate a temp memory the same size as input batched matrix, and allocate a tensor for the gradient of the input vector. The gradient of the batched
matrix will be stored in the memory allocated during the forward phase.
3. Matrix-matrix multiplication => no temp memory, forward memory is the size of output tensor, backward memory is the size of the two inputs
4. Batched matrix-matrix multiplication => if the first input is the batched matrix, no temp memory, the forward memory cost is the size of two
inputs and backward memory cost is the size of the output tensor; if the second input is the batched matrix, during the forward phase it will allocate
memory for the output and gradient of the second input, and has a temp memory the same size as the output, during the backward phase, it
will allocate memory for the gradient of the first input and has a temp memory which is as big as output and the second input.
5. Batched matrix-batched matrix multiplication => if the two inputs have the same batch dimensions, no temp memory, the forward memory cost is the size
of output, backward memory cost is the size of the two inputs; it the two inputs have different batch dimensions, during the forward phase it will allocate
memory of the expanded inputs (so that the batch dimensions could match) and the output, and during the backward phase, it has a temp memory of the size of
two expanded inputs, and it will allocate memory for the gradient of the two inputs and discard the expanded inputs allocated during the forward phase.
Returns:
Tuple[TrainCycleItem, TrainCycleItem, bool]: compute cost, memory cost and forward inputs